Skip to content

Instantly share code, notes, and snippets.

@bombol
Created June 27, 2016 20:48
Show Gist options
  • Save bombol/72cfa32417f81d5d79193a50ea588bb4 to your computer and use it in GitHub Desktop.
Save bombol/72cfa32417f81d5d79193a50ea588bb4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 218,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import theano\n",
"import theano.tensor as T\n",
"import lasagne\n",
" \n",
"from sklearn.datasets import load_iris\n",
"# NOTE(review): sklearn.cross_validation was deprecated and later removed;\n",
"# on modern scikit-learn, import train_test_split from sklearn.model_selection.\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"# Load and store features as X and targets as y\n",
"# (float32 features / int8 targets to match the theano placeholders below)\n",
"iris = load_iris()\n",
"X = np.asarray(iris.data, dtype='float32')\n",
"y = np.asarray(iris.target, dtype='int8')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Here we set one of the conditionals to True if we wish to do binary classification"
]
},
{
"cell_type": "code",
"execution_count": 219,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# default is to do 3-class classification\n",
"y_ = y; X_ = X\n",
"nonlin = lasagne.nonlinearities.softmax \n",
"loss_fn = lasagne.objectives.categorical_crossentropy\n",
"units = 3\n",
"\n",
"# IF THIS IS TRUE, WE ELIMINATE THE CLASS WHERE Y=2 AND DO BINARY CLASSIFICATION\n",
"# WITH SIGMOID AND BINARY CROSS ENTROPY\n",
"binary = True #False\n",
"\n",
"if (binary):\n",
" # keep only classes 0 and 1; targets reshaped to (N, 1) so they match\n",
" # the shape of the single sigmoid output expected by binary_crossentropy\n",
" y_ = y[y<2].reshape(-1,1)\n",
" X_ = X[y<2]\n",
" units = 1\n",
" nonlin = lasagne.nonlinearities.sigmoid\n",
" loss_fn = lasagne.objectives.binary_crossentropy\n",
" \n",
"# IF THIS IS TRUE, WE ELIMINATE THE CLASS WHERE Y=2 AND DO BINARY CLASSIFICATION\n",
"# WITH SOFTMAX AND CATEGORICAL CROSS ENTROPY\n",
"# (disabled alternative: 2-way classification treated as a 2-unit softmax;\n",
"# flip `if False:` to True and set binary = False above to try it)\n",
"if False:\n",
" y_ = y[y<2]\n",
" X_ = X[y<2]\n",
" units = 2\n",
" #nonlin = lasagne.nonlinearities.softmax\n",
" #loss_fn = lasagne.objectives.categorical_crossentropy\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 220,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0],\n",
" [0],\n",
" [0],\n",
" [1],\n",
" [1]], dtype=int8)"
]
},
"execution_count": 220,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# 80/20 train/test split; the first 15 test rows are carved out as a\n",
"# validation set (no stratify argument, so class balance is not guaranteed)\n",
"X_train, X_test, y_train, y_test = train_test_split(X_, y_, train_size=0.8)\n",
"X_val = X_test[:15]\n",
"y_val = y_test[:15]\n",
"X_test = X_test[15:]\n",
"y_test = y_test[15:]\n",
"\n",
"y_test # print y values to make sure they are the proper form"
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Infinite minibatch generator: each next() draws N row indices uniformly at\n",
"# random *with replacement* (np.random.choice default), so a batch may repeat\n",
"# samples and an 'epoch' is N_BATCHES independent draws, not a full pass.\n",
"def batch_gen(X, y, N):\n",
" while True:\n",
" idx = np.random.choice(len(y), N)\n",
" yield X[idx].astype('float32'), y[idx].astype('int8')"
]
},
{
"cell_type": "code",
"execution_count": 229,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Network: a single dense layer on top of the raw features\n",
"\n",
"# 4-dim feature vector on input; batch dimension left unspecified\n",
"l_in = lasagne.layers.InputLayer((None, 4))\n",
"# `units` outputs with the nonlinearity picked in the configuration cell\n",
"# above (1 sigmoid unit in the binary path, 3 softmax units otherwise)\n",
"l_out = lasagne.layers.DenseLayer(l_in, num_units=units, nonlinearity=nonlin)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 226,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"X_sym = T.matrix('X')\n",
"y_sym = T.ivector('y')\n",
"\n",
"# Expression for the output distribution\n",
"output = lasagne.layers.get_output(l_out, X_sym)\n",
"pred = output.argmax(-1) \n",
"\n",
"# Binary path: targets were reshaped to (N, 1) above, so the target\n",
"# placeholder becomes a float matrix and prediction is a 0.5 threshold\n",
"# on the sigmoid output instead of an argmax\n",
"if (binary):\n",
" y_sym = T.matrix()\n",
" pred = (output > 0.5)\n",
"\n",
"# Loss function (loss_fn / y_sym were chosen consistently above)\n",
"loss = T.mean(loss_fn(output, y_sym))\n",
"acc = T.mean(T.eq(pred, y_sym))\n",
" \n",
"# We retrieve the parameters\n",
"params = lasagne.layers.get_all_params(l_out)\n",
" \n",
"# Compute the gradient of the loss function with respect to the parameters.\n",
"# The stochastic gradient descent updates the parameters\n",
"grad = T.grad(loss, params)\n",
"updates = lasagne.updates.sgd(grad, params, learning_rate=0.05)\n",
" \n",
"# Define a training function\n",
"f_train = theano.function([X_sym, y_sym], [loss, acc], updates=updates)\n",
" \n",
"# A validation function, similar but it doesn't alter the parameters\n",
"f_val = theano.function([X_sym, y_sym], [loss, acc])\n",
" \n",
"# Prediction function,\n",
"f_predict = theano.function([X_sym], pred)\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Alternative formulation of the loss/update/accuracy expressions using\n",
"# lasagne.objectives.binary_accuracy (assumes the binary/sigmoid path).\n",
"# BUG FIX: this cell referenced undefined names `network`, `input_var` and\n",
"# `target_var` (a NameError if run) -- rewritten against the notebook's own\n",
"# l_out / X_sym / y_sym, passing X_sym to get_output so the compiled\n",
"# functions' declared inputs are actually part of the graph.\n",
"\n",
"# Create a loss expression for training\n",
"prediction = lasagne.layers.get_output(l_out, X_sym)\n",
"loss = lasagne.objectives.binary_crossentropy(prediction, y_sym)\n",
"loss = loss.mean()\n",
"\n",
"# Create update expressions for training (Stochastic Gradient Descent (SGD))\n",
"params = lasagne.layers.get_all_params(l_out, trainable=True)\n",
"updates = lasagne.updates.sgd(loss, params, learning_rate=0.01)\n",
"\n",
"# Create a loss expression for validation/testing. \n",
"test_prediction = lasagne.layers.get_output(l_out, X_sym, deterministic=True)\n",
"test_loss = lasagne.objectives.binary_crossentropy(test_prediction, y_sym)\n",
"test_loss = test_loss.mean()\n",
"\n",
"# Accuracy\n",
"test_acc = lasagne.objectives.binary_accuracy(test_prediction, y_sym)\n",
"test_acc = test_acc.mean()\n",
"\n",
"# Compile a function performing a training step \n",
"train_fn = theano.function([X_sym, y_sym], loss, updates=updates)\n",
"\n",
"# Compile a second function computing the validation loss and accuracy\n",
"val_fn = theano.function([X_sym, y_sym], [test_loss, test_acc])"
]
},
{
"cell_type": "code",
"execution_count": 227,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Batch size choice and the number of batches per epoch\n",
"BATCH_SIZE = 10\n",
"N_BATCHES = len(X_train) // BATCH_SIZE\n",
"# X_val has 15 rows, so with BATCH_SIZE=10 this is a single validation batch\n",
"N_VAL_BATCHES = len(X_val) // BATCH_SIZE\n",
" \n",
"# Minibatch generators\n",
"train_batches = batch_gen(X_train, y_train, BATCH_SIZE)\n",
"val_batches = batch_gen(X_val, y_val, BATCH_SIZE)\n"
]
},
{
"cell_type": "code",
"execution_count": 228,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training... for 8 batches\n",
"\n",
"Epoch\t|\tTime\t| Training loss\t| Training acc\t| Valid. loss\t| Valid. Accuracy | Valid/Train loss\n",
"________________________________________________________________________________________________________\n",
"1\t|\t0.004s\t|\t0.464\t|\t87.50%\t|\t0.366\t|\t100.00%\t|\t0.789\n",
"2\t|\t0.004s\t|\t0.314\t|\t100.00%\t|\t0.323\t|\t100.00%\t|\t1.028\n",
"3\t|\t0.004s\t|\t0.289\t|\t100.00%\t|\t0.266\t|\t100.00%\t|\t0.921\n",
"4\t|\t0.004s\t|\t0.255\t|\t100.00%\t|\t0.232\t|\t100.00%\t|\t0.906\n",
"5\t|\t0.004s\t|\t0.229\t|\t100.00%\t|\t0.196\t|\t100.00%\t|\t0.857\n",
"6\t|\t0.004s\t|\t0.192\t|\t100.00%\t|\t0.188\t|\t100.00%\t|\t0.981\n",
"7\t|\t0.004s\t|\t0.177\t|\t100.00%\t|\t0.164\t|\t100.00%\t|\t0.926\n",
"8\t|\t0.004s\t|\t0.153\t|\t100.00%\t|\t0.143\t|\t100.00%\t|\t0.934\n",
"9\t|\t0.004s\t|\t0.159\t|\t100.00%\t|\t0.137\t|\t100.00%\t|\t0.861\n",
"10\t|\t0.004s\t|\t0.130\t|\t100.00%\t|\t0.138\t|\t100.00%\t|\t1.066\n",
"11\t|\t0.004s\t|\t0.120\t|\t100.00%\t|\t0.127\t|\t100.00%\t|\t1.056\n",
"12\t|\t0.004s\t|\t0.117\t|\t100.00%\t|\t0.117\t|\t100.00%\t|\t0.994\n",
"13\t|\t0.004s\t|\t0.108\t|\t100.00%\t|\t0.112\t|\t100.00%\t|\t1.033\n",
"14\t|\t0.004s\t|\t0.103\t|\t100.00%\t|\t0.102\t|\t100.00%\t|\t0.984\n",
"15\t|\t0.004s\t|\t0.096\t|\t100.00%\t|\t0.091\t|\t100.00%\t|\t0.954\n",
"16\t|\t0.004s\t|\t0.086\t|\t100.00%\t|\t0.087\t|\t100.00%\t|\t1.005\n",
"17\t|\t0.004s\t|\t0.096\t|\t100.00%\t|\t0.093\t|\t100.00%\t|\t0.966\n",
"18\t|\t0.004s\t|\t0.085\t|\t100.00%\t|\t0.088\t|\t100.00%\t|\t1.025\n",
"19\t|\t0.004s\t|\t0.085\t|\t100.00%\t|\t0.089\t|\t100.00%\t|\t1.045\n",
"20\t|\t0.004s\t|\t0.073\t|\t100.00%\t|\t0.081\t|\t100.00%\t|\t1.099\n",
"21\t|\t0.004s\t|\t0.076\t|\t100.00%\t|\t0.073\t|\t100.00%\t|\t0.961\n",
"22\t|\t0.004s\t|\t0.074\t|\t100.00%\t|\t0.071\t|\t100.00%\t|\t0.965\n",
"23\t|\t0.004s\t|\t0.067\t|\t100.00%\t|\t0.065\t|\t100.00%\t|\t0.980\n",
"24\t|\t0.004s\t|\t0.072\t|\t100.00%\t|\t0.063\t|\t100.00%\t|\t0.870\n",
"25\t|\t0.004s\t|\t0.066\t|\t100.00%\t|\t0.062\t|\t100.00%\t|\t0.933\n",
"26\t|\t0.004s\t|\t0.068\t|\t100.00%\t|\t0.068\t|\t100.00%\t|\t0.993"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 26%|██▌ | 26/100 [00:00<00:00, 254.12it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"27\t|\t0.004s\t|\t0.059\t|\t100.00%\t|\t0.063\t|\t100.00%\t|\t1.069\n",
"28\t|\t0.004s\t|\t0.058\t|\t100.00%\t|\t0.054\t|\t100.00%\t|\t0.921\n",
"29\t|\t0.004s\t|\t0.060\t|\t100.00%\t|\t0.058\t|\t100.00%\t|\t0.971\n",
"30\t|\t0.004s\t|\t0.056\t|\t100.00%\t|\t0.053\t|\t100.00%\t|\t0.943\n",
"31\t|\t0.004s\t|\t0.051\t|\t100.00%\t|\t0.056\t|\t100.00%\t|\t1.099\n",
"32\t|\t0.004s\t|\t0.050\t|\t100.00%\t|\t0.047\t|\t100.00%\t|\t0.939\n",
"33\t|\t0.004s\t|\t0.048\t|\t100.00%\t|\t0.049\t|\t100.00%\t|\t1.005\n",
"34\t|\t0.004s\t|\t0.047\t|\t100.00%\t|\t0.049\t|\t100.00%\t|\t1.037\n",
"35\t|\t0.004s\t|\t0.052\t|\t100.00%\t|\t0.047\t|\t100.00%\t|\t0.895\n",
"36\t|\t0.004s\t|\t0.050\t|\t100.00%\t|\t0.046\t|\t100.00%\t|\t0.921\n",
"37\t|\t0.004s\t|\t0.047\t|\t100.00%\t|\t0.040\t|\t100.00%\t|\t0.850\n",
"38\t|\t0.004s\t|\t0.043\t|\t100.00%\t|\t0.046\t|\t100.00%\t|\t1.070\n",
"39\t|\t0.004s\t|\t0.044\t|\t100.00%\t|\t0.041\t|\t100.00%\t|\t0.925"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 52%|█████▏ | 52/100 [00:00<00:00, 255.42it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"40\t|\t0.004s\t|\t0.040\t|\t100.00%\t|\t0.039\t|\t100.00%\t|\t0.991\n",
"41\t|\t0.004s\t|\t0.042\t|\t100.00%\t|\t0.040\t|\t100.00%\t|\t0.957\n",
"42\t|\t0.004s\t|\t0.043\t|\t100.00%\t|\t0.044\t|\t100.00%\t|\t1.030\n",
"43\t|\t0.004s\t|\t0.039\t|\t100.00%\t|\t0.044\t|\t100.00%\t|\t1.108\n",
"44\t|\t0.004s\t|\t0.035\t|\t100.00%\t|\t0.040\t|\t100.00%\t|\t1.155\n",
"45\t|\t0.004s\t|\t0.035\t|\t100.00%\t|\t0.036\t|\t100.00%\t|\t1.008\n",
"46\t|\t0.004s\t|\t0.033\t|\t100.00%\t|\t0.044\t|\t100.00%\t|\t1.351\n",
"47\t|\t0.004s\t|\t0.037\t|\t100.00%\t|\t0.035\t|\t100.00%\t|\t0.933\n",
"48\t|\t0.004s\t|\t0.035\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t0.839\n",
"49\t|\t0.004s\t|\t0.041\t|\t100.00%\t|\t0.035\t|\t100.00%\t|\t0.858\n",
"50\t|\t0.004s\t|\t0.036\t|\t100.00%\t|\t0.036\t|\t100.00%\t|\t0.979\n",
"51\t|\t0.004s\t|\t0.037\t|\t100.00%\t|\t0.036\t|\t100.00%\t|\t0.959\n",
"52\t|\t0.004s\t|\t0.031\t|\t100.00%\t|\t0.037\t|\t100.00%\t|\t1.172\n",
"53\t|\t0.004s\t|\t0.033\t|\t100.00%\t|\t0.041\t|\t100.00%\t|\t1.242\n",
"54\t|\t0.004s\t|\t0.033\t|\t100.00%\t|\t0.041\t|\t100.00%\t|\t1.225\n",
"55\t|\t0.004s\t|\t0.036\t|\t100.00%\t|\t0.028\t|\t100.00%\t|\t0.795\n",
"56\t|\t0.004s\t|\t0.031\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t0.924\n",
"57\t|\t0.004s\t|\t0.031\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t0.921\n",
"58\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t1.180\n",
"59\t|\t0.004s\t|\t0.027\t|\t100.00%\t|\t0.033\t|\t100.00%\t|\t1.194\n",
"60\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.025\t|\t100.00%\t|\t0.941\n",
"61\t|\t0.004s\t|\t0.029\t|\t100.00%\t|\t0.032\t|\t100.00%\t|\t1.125\n",
"62\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.026\t|\t100.00%\t|\t1.028\n",
"63\t|\t0.004s\t|\t0.029\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t1.001\n",
"64\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.028\t|\t100.00%\t|\t1.118\n",
"65\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.028\t|\t100.00%\t|\t1.048\n",
"66\t|\t0.004s\t|\t0.033\t|\t100.00%\t|\t0.026\t|\t100.00%\t|\t0.792\n",
"67\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.026\t|\t100.00%\t|\t0.986"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 78%|███████▊ | 78/100 [00:00<00:00, 255.56it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"68\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.025\t|\t100.00%\t|\t1.010\n",
"69\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.029\t|\t100.00%\t|\t1.216\n",
"70\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.028\t|\t100.00%\t|\t1.209\n",
"71\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.024\t|\t100.00%\t|\t0.927\n",
"72\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.023\t|\t100.00%\t|\t0.872\n",
"73\t|\t0.004s\t|\t0.024\t|\t100.00%\t|\t0.030\t|\t100.00%\t|\t1.268\n",
"74\t|\t0.004s\t|\t0.022\t|\t100.00%\t|\t0.028\t|\t100.00%\t|\t1.244\n",
"75\t|\t0.004s\t|\t0.024\t|\t100.00%\t|\t0.022\t|\t100.00%\t|\t0.924\n",
"76\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t0.820\n",
"77\t|\t0.004s\t|\t0.024\t|\t100.00%\t|\t0.024\t|\t100.00%\t|\t0.985\n",
"78\t|\t0.004s\t|\t0.026\t|\t100.00%\t|\t0.024\t|\t100.00%\t|\t0.913\n",
"79\t|\t0.004s\t|\t0.022\t|\t100.00%\t|\t0.018\t|\t100.00%\t|\t0.812\n",
"80\t|\t0.004s\t|\t0.028\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t0.732\n",
"81\t|\t0.004s\t|\t0.019\t|\t100.00%\t|\t0.022\t|\t100.00%\t|\t1.191\n",
"82\t|\t0.004s\t|\t0.021\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t0.985\n",
"83\t|\t0.004s\t|\t0.022\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t0.936\n",
"84\t|\t0.004s\t|\t0.024\t|\t100.00%\t|\t0.020\t|\t100.00%\t|\t0.829\n",
"85\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t0.818\n",
"86\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.022\t|\t100.00%\t|\t0.937\n",
"87\t|\t0.004s\t|\t0.021\t|\t100.00%\t|\t0.019\t|\t100.00%\t|\t0.929\n",
"88\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.025\t|\t100.00%\t|\t1.088\n",
"89\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.022\t|\t100.00%\t|\t0.960\n",
"90\t|\t0.004s\t|\t0.019\t|\t100.00%\t|\t0.023\t|\t100.00%\t|\t1.231\n",
"91\t|\t0.004s\t|\t0.023\t|\t100.00%\t|\t0.019\t|\t100.00%\t|\t0.801\n",
"92\t|\t0.004s\t|\t0.019\t|\t100.00%\t|\t0.022\t|\t100.00%\t|\t1.125\n",
"93\t|\t0.004s\t|\t0.019\t|\t100.00%\t|\t0.020\t|\t100.00%\t|\t1.019\n",
"94\t|\t0.004s\t|\t0.025\t|\t100.00%\t|\t0.018\t|\t100.00%\t|\t0.723"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"100%|██████████| 100/100 [00:00<00:00, 255.85it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"95\t|\t0.004s\t|\t0.017\t|\t100.00%\t|\t0.018\t|\t100.00%\t|\t1.067\n",
"96\t|\t0.004s\t|\t0.018\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t1.193\n",
"97\t|\t0.004s\t|\t0.019\t|\t100.00%\t|\t0.020\t|\t100.00%\t|\t1.032\n",
"98\t|\t0.004s\t|\t0.015\t|\t100.00%\t|\t0.021\t|\t100.00%\t|\t1.397\n",
"99\t|\t0.004s\t|\t0.018\t|\t100.00%\t|\t0.018\t|\t100.00%\t|\t0.982\n",
"100\t|\t0.004s\t|\t0.017\t|\t100.00%\t|\t0.018\t|\t100.00%\t|\t1.057\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"import time\n",
"#tqdm()\n",
"\n",
"print(\"Starting training... for \"+str(N_BATCHES)+\" batches\\n\")\n",
"print(\"Epoch\\t|\\tTime\\t| Training loss\\t| Training acc\\t| Valid. loss\\t| Valid. Accuracy | Valid/Train loss\")\n",
"print(\"________________________________________________________________________________________________________\")\n",
"for epoch in tqdm(range(100)):\n",
"    start_time = time.time()\n",
"\n",
"    # accumulate loss/accuracy over one epoch of training minibatches\n",
"    train_loss = 0\n",
"    train_acc = 0\n",
"    for _ in range(N_BATCHES):\n",
"        X, y = next(train_batches)\n",
"        loss, acc = f_train(X, y)\n",
"        train_loss += loss\n",
"        train_acc += acc\n",
"    train_loss /= N_BATCHES\n",
"    train_acc /= N_BATCHES\n",
"\n",
"    val_loss = 0\n",
"    val_acc = 0\n",
"\n",
"    # BUG FIX: the validation loop previously drew N_BATCHES batches from\n",
"    # train_batches, so the 'Valid.' columns just re-reported training data.\n",
"    # Draw from the validation generator and average over N_VAL_BATCHES.\n",
"    for _ in range(N_VAL_BATCHES):\n",
"        X, y = next(val_batches)\n",
"        loss, acc = f_val(X, y)\n",
"        val_loss += loss\n",
"        val_acc += acc\n",
"    val_loss /= N_VAL_BATCHES\n",
"    val_acc /= N_VAL_BATCHES\n",
"\n",
"    print(\"{}\\t|\\t{:.3f}s\\t|\\t{:.3f}\\t|\\t{:.2f}%\\t|\\t{:.3f}\\t|\\t{:.2f}%\\t|\\t{:.03f}\".format(\n",
"        epoch + 1,time.time() - start_time,\n",
"        float(train_loss),float(train_acc)*100,float(val_loss),val_acc * 100, val_loss/train_loss)) \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### ORIGINAL COMPLETE EXAMPLE"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, Train (val) loss 0.827 (0.804) ratio 0.972\n",
"Train (val) accuracy 0.650 (0.692)\n",
"Epoch 1, Train (val) loss 0.518 (0.562) ratio 1.085\n",
"Train (val) accuracy 0.750 (0.667)\n",
"Epoch 2, Train (val) loss 0.641 (0.683) ratio 1.066\n",
"Train (val) accuracy 0.617 (0.708)\n",
"Epoch 3, Train (val) loss 0.630 (0.503) ratio 0.799\n",
"Train (val) accuracy 0.650 (0.650)\n",
"Epoch 4, Train (val) loss 0.513 (0.637) ratio 1.241\n",
"Train (val) accuracy 0.742 (0.700)\n",
"Epoch 5, Train (val) loss 0.509 (0.668) ratio 1.311\n",
"Train (val) accuracy 0.792 (0.700)\n",
"Epoch 6, Train (val) loss 0.497 (0.459) ratio 0.924\n",
"Train (val) accuracy 0.750 (0.667)\n",
"Epoch 7, Train (val) loss 0.567 (0.363) ratio 0.640\n",
"Train (val) accuracy 0.683 (0.908)\n",
"Epoch 8, Train (val) loss 0.450 (0.378) ratio 0.839\n",
"Train (val) accuracy 0.817 (0.767)\n",
"Epoch 9, Train (val) loss 0.481 (0.465) ratio 0.967\n",
"Train (val) accuracy 0.733 (0.633)\n"
]
}
],
"source": [
"\n",
"import numpy as np\n",
"import theano\n",
"import theano.tensor as T\n",
"import lasagne\n",
"\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"\n",
"def to_categorical(y, nb_classes=None):\n",
" y = np.asarray(y, dtype='int32')\n",
" if not nb_classes:\n",
" nb_classes = np.max(y)+1\n",
" Y = np.zeros((len(y), nb_classes))\n",
" for i in range(len(y)):\n",
" Y[i, y[i]] = 1.\n",
" return Y\n",
"\n",
"# Load and store features as X and targets as y\n",
"iris = load_iris()\n",
"X = iris.data\n",
"y = iris.target\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)\n",
"X_val = X_test[:15]\n",
"y_val = y_test[:15]\n",
"X_test = X_test[15:]\n",
"y_test = y_test[15:]\n",
"\n",
"# Categorize them for use in categorical cross-entropy\n",
"categorized_train = np.asarray(y_train, dtype='int32')\n",
"categorized_test = np.asarray(y_test, dtype='int32')\n",
"categorized_val = np.asarray(y_val, dtype='int32')\n",
"\n",
"def batch_gen(X, y, N):\n",
" while True:\n",
" idx = np.random.choice(len(y), N)\n",
" yield X[idx].astype('float32'), y[idx].astype('int32')\n",
"\n",
"\n",
"# 4-dim vector on input\n",
"l_in = lasagne.layers.InputLayer((None, 4))\n",
"# 3-dim ivector on output\n",
"l_out = lasagne.layers.DenseLayer(l_in, num_units=3, nonlinearity=lasagne.nonlinearities.softmax)\n",
"\n",
"X_sym = T.matrix('X')\n",
"y_sym = T.ivector('y')\n",
"\n",
"# Expression for the output distribution\n",
"output = lasagne.layers.get_output(l_out, X_sym)\n",
"pred = output.argmax(-1)\n",
"\n",
"# Loss function\n",
"loss = T.mean(lasagne.objectives.categorical_crossentropy(output, y_sym))\n",
"acc = T.mean(T.eq(pred, y_sym))\n",
"\n",
"# We retrieve the parameters\n",
"params = lasagne.layers.get_all_params(l_out)\n",
"\n",
"# Compute the gradient of the loss function with respect to the parameters.\n",
"# The stochastic gradient descent updates the parameters\n",
"grad = T.grad(loss, params)\n",
"updates = lasagne.updates.sgd(grad, params, learning_rate=0.05)\n",
"\n",
"# Define a training function\n",
"f_train = theano.function([X_sym, y_sym], [loss, acc], updates=updates)\n",
"\n",
"# A validation function, similar but it doesn't alter the parameters\n",
"f_val = theano.function([X_sym, y_sym], [loss, acc])\n",
"\n",
"# Prediction function,\n",
"f_predict = theano.function([X_sym], pred)\n",
"\n",
"# Batch size choice and the number of batches per epoch\n",
"BATCH_SIZE = 5\n",
"N_BATCHES = len(X_train) // BATCH_SIZE\n",
"N_VAL_BATCHES = len(X_val) // BATCH_SIZE\n",
"\n",
"# Minibatch generators\n",
"train_batches = batch_gen(X_train, categorized_train, BATCH_SIZE)\n",
"val_batches = batch_gen(X_val, categorized_val, BATCH_SIZE)\n",
"\n",
"for epoch in range(10):\n",
" train_loss = 0\n",
" train_acc = 0\n",
" for _ in range(N_BATCHES):\n",
" X, y = next(train_batches)\n",
" loss, acc = f_train(X, y)\n",
" train_loss += loss\n",
" train_acc += acc\n",
" train_loss /= N_BATCHES\n",
" train_acc /= N_BATCHES\n",
"\n",
"    val_loss = 0\n",
"    val_acc = 0\n",
"\n",
"    # BUG FIX: validate on the validation generator (was train_batches) and\n",
"    # average over N_VAL_BATCHES (was N_BATCHES); the reported validation\n",
"    # numbers otherwise measured more training data.\n",
"    for _ in range(N_VAL_BATCHES):\n",
"        X, y = next(val_batches)\n",
"        loss, acc = f_val(X, y)\n",
"        val_loss += loss\n",
"        val_acc += acc\n",
"    val_loss /= N_VAL_BATCHES\n",
"    val_acc /= N_VAL_BATCHES\n",
"\n",
" print('Epoch {}, Train (val) loss {:.03f} ({:.03f}) ratio {:.03f}'.format(\n",
" epoch, train_loss, val_loss, val_loss/train_loss))\n",
" print('Train (val) accuracy {:.03f} ({:.03f})'.format(train_acc, val_acc))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment