Last active
September 18, 2018 11:35
-
-
Save chck/294f06e2032f70d7303fc528564dd17f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\"\"\"https://github.com/tbennun/keras-bucketed-sequence.git\n", | |
"Bucketing technique for NLP by Keras\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from tensorflow.keras.layers import Input, LSTM, Dense\n", | |
"from tensorflow.keras.models import Model\n", | |
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n", | |
"from absl import app, flags\n", | |
"import numpy as np\n", | |
"\n", | |
"from bucketed_sequence import BucketedSequence" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"UNK = -1.0\n", | |
"batch_size = 64\n", | |
"epochs = 20\n", | |
"lstm_units = 100\n", | |
"dense_breadth = 32\n", | |
"\n", | |
"dataset_size = 10000\n", | |
"val_size = 1000\n", | |
"seqlen_mean = 50\n", | |
"seqlen_stddev = 200\n", | |
"\n", | |
"buckets = 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[-1.00000000e+000, -1.00000000e+000, -1.00000000e+000,\n", | |
" -1.00000000e+000, -1.00000000e+000, 0.00000000e+000,\n", | |
" 1.73060038e-077, 2.23111331e-314, 2.23093651e-314,\n", | |
" 2.23093663e-314, 2.23077450e-314, 2.23080756e-314,\n", | |
" 0.00000000e+000, 2.15575018e-314, 7.14433543e-309]])" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def pad(seqs, maxlen):\n", | |
" # Note: prepends data\n", | |
" padded = np.array(pad_sequences(seqs, maxlen=maxlen, value=UNK, dtype=seqs[0].dtype))\n", | |
" return np.vstack([np.expand_dims(x, axis=0) for x in padded])\n", | |
"\n", | |
"pad(seqs=np.empty((1,10)), maxlen=15)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((1000, 594, 1), (1000,), (1000,))" | |
] | |
}, | |
"execution_count": 72, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def gen_dataset(set_size):\n", | |
" sequence_lendths = np.random.normal(loc=seqlen_mean, scale=seqlen_stddev, size=set_size).astype(np.int32)\n", | |
" max_length = np.max(sequence_lendths)\n", | |
" # Clamp range to start from three elements\n", | |
" sequence_lendths = np.clip(sequence_lendths, 3, max_length)\n", | |
" \n", | |
" # Generate random sequences\n", | |
" seq_x = [np.random.uniform(1.0, 50.0, sl) for sl in sequence_lendths]\n", | |
" seq_y = np.array([seq[2] for seq in seq_x], dtype=np.float32)\n", | |
" \n", | |
" # Pad sequences\n", | |
" padded_x = pad(seq_x, max_length)\n", | |
" padded_x = np.reshape(padded_x, (len(sequence_lendths), max_length, 1))\n", | |
" \n", | |
" # Return dataset\n", | |
" return padded_x, seq_y, sequence_lendths\n", | |
"\n", | |
"a, b, c = gen_dataset(set_size=val_size)\n", | |
"a.shape, b.shape, c.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<svg height=\"264pt\" viewBox=\"0.00 0.00 112.25 264.00\" width=\"112pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 260)\">\n", | |
"<title>G</title>\n", | |
"<polygon fill=\"#ffffff\" points=\"-4,4 -4,-260 108.252,-260 108.252,4 -4,4\" stroke=\"transparent\"/>\n", | |
"<!-- 4719619544 -->\n", | |
"<g class=\"node\" id=\"node1\">\n", | |
"<title>4719619544</title>\n", | |
"<polygon fill=\"none\" points=\"3.8896,-219.5 3.8896,-255.5 100.3623,-255.5 100.3623,-219.5 3.8896,-219.5\" stroke=\"#000000\"/>\n", | |
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-233.3\">in: InputLayer</text>\n", | |
"</g>\n", | |
"<!-- 4719354656 -->\n", | |
"<g class=\"node\" id=\"node2\">\n", | |
"<title>4719354656</title>\n", | |
"<polygon fill=\"none\" points=\"9.7036,-146.5 9.7036,-182.5 94.5483,-182.5 94.5483,-146.5 9.7036,-146.5\" stroke=\"#000000\"/>\n", | |
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-160.3\">lstm: LSTM</text>\n", | |
"</g>\n", | |
"<!-- 4719619544->4719354656 -->\n", | |
"<g class=\"edge\" id=\"edge1\">\n", | |
"<title>4719619544->4719354656</title>\n", | |
"<path d=\"M52.126,-219.4551C52.126,-211.3828 52.126,-201.6764 52.126,-192.6817\" fill=\"none\" stroke=\"#000000\"/>\n", | |
"<polygon fill=\"#000000\" points=\"55.6261,-192.5903 52.126,-182.5904 48.6261,-192.5904 55.6261,-192.5903\" stroke=\"#000000\"/>\n", | |
"</g>\n", | |
"<!-- 4719394376 -->\n", | |
"<g class=\"node\" id=\"node3\">\n", | |
"<title>4719394376</title>\n", | |
"<polygon fill=\"none\" points=\"0,-73.5 0,-109.5 104.252,-109.5 104.252,-73.5 0,-73.5\" stroke=\"#000000\"/>\n", | |
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-87.3\">dense_6: Dense</text>\n", | |
"</g>\n", | |
"<!-- 4719354656->4719394376 -->\n", | |
"<g class=\"edge\" id=\"edge2\">\n", | |
"<title>4719354656->4719394376</title>\n", | |
"<path d=\"M52.126,-146.4551C52.126,-138.3828 52.126,-128.6764 52.126,-119.6817\" fill=\"none\" stroke=\"#000000\"/>\n", | |
"<polygon fill=\"#000000\" points=\"55.6261,-119.5903 52.126,-109.5904 48.6261,-119.5904 55.6261,-119.5903\" stroke=\"#000000\"/>\n", | |
"</g>\n", | |
"<!-- 4719489208 -->\n", | |
"<g class=\"node\" id=\"node4\">\n", | |
"<title>4719489208</title>\n", | |
"<polygon fill=\"none\" points=\"0,-.5 0,-36.5 104.252,-36.5 104.252,-.5 0,-.5\" stroke=\"#000000\"/>\n", | |
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-14.3\">dense_7: Dense</text>\n", | |
"</g>\n", | |
"<!-- 4719394376->4719489208 -->\n", | |
"<g class=\"edge\" id=\"edge3\">\n", | |
"<title>4719394376->4719489208</title>\n", | |
"<path d=\"M52.126,-73.4551C52.126,-65.3828 52.126,-55.6764 52.126,-46.6817\" fill=\"none\" stroke=\"#000000\"/>\n", | |
"<polygon fill=\"#000000\" points=\"55.6261,-46.5903 52.126,-36.5904 48.6261,-46.5904 55.6261,-46.5903\" stroke=\"#000000\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.SVG object>" | |
] | |
}, | |
"execution_count": 77, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def build_model():\n", | |
" # Set up a single network (LSTM + Dense))\n", | |
" inp = Input(shape=(None, 1), dtype=\"float32\", name=\"in\")\n", | |
" lstm = LSTM(lstm_units, return_sequences=False, name=\"lstm\")(inp)\n", | |
" dense = Dense(dense_breadth, kernel_initializer='normal', activation='relu')(lstm)\n", | |
" outputs = Dense(1, kernel_initializer='normal')(dense)\n", | |
" return Model(inputs=inp, outputs=outputs)\n", | |
"\n", | |
"model = build_model()\n", | |
"model.compile(optimizer=\"adam\", loss=\"mean_squared_error\", metrics=['mae'])\n", | |
"\n", | |
"from IPython.display import SVG\n", | |
"from tensorflow.python.keras.utils.vis_utils import model_to_dot\n", | |
"\n", | |
"SVG(model_to_dot(model).create(prog='dot', format='svg'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Generate Dataset\n", | |
"x_train, y_train, len_train = gen_dataset(dataset_size)\n", | |
"x_val, y_val, len_val = gen_dataset(dataset_size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Training with 10 non-empty buckets\n", | |
"Training with 10 non-empty buckets\n", | |
"Epoch 1/20\n", | |
"163/163 [==============================] - 35s 218ms/step - loss: 195.1114 - mean_absolute_error: 10.3456 - val_loss: 118.8506 - val_mean_absolute_error: 7.5571\n", | |
"Epoch 2/20\n", | |
"163/163 [==============================] - 36s 219ms/step - loss: 120.9951 - mean_absolute_error: 7.6279 - val_loss: 121.1408 - val_mean_absolute_error: 7.4981\n", | |
"Epoch 3/20\n", | |
"163/163 [==============================] - 37s 230ms/step - loss: 122.0102 - mean_absolute_error: 7.6568 - val_loss: 118.5691 - val_mean_absolute_error: 7.4961\n", | |
"Epoch 4/20\n", | |
"163/163 [==============================] - 37s 229ms/step - loss: 124.0299 - mean_absolute_error: 7.7591 - val_loss: 122.0257 - val_mean_absolute_error: 7.5287\n", | |
"Epoch 5/20\n", | |
"163/163 [==============================] - 39s 239ms/step - loss: 124.0682 - mean_absolute_error: 7.8174 - val_loss: 119.2843 - val_mean_absolute_error: 7.7904\n", | |
"Epoch 6/20\n", | |
"163/163 [==============================] - 37s 229ms/step - loss: 119.0774 - mean_absolute_error: 7.5436 - val_loss: 118.4170 - val_mean_absolute_error: 7.4253\n", | |
"Epoch 7/20\n", | |
"163/163 [==============================] - 44s 271ms/step - loss: 121.2110 - mean_absolute_error: 7.6321 - val_loss: 118.0953 - val_mean_absolute_error: 7.5249\n", | |
"Epoch 8/20\n", | |
"163/163 [==============================] - 43s 263ms/step - loss: 121.8157 - mean_absolute_error: 7.7071 - val_loss: 117.6462 - val_mean_absolute_error: 7.3396\n", | |
"Epoch 9/20\n", | |
"163/163 [==============================] - 41s 252ms/step - loss: 115.4723 - mean_absolute_error: 7.5024 - val_loss: 91.7777 - val_mean_absolute_error: 6.6707\n", | |
"Epoch 10/20\n", | |
"163/163 [==============================] - 47s 286ms/step - loss: 53.6427 - mean_absolute_error: 5.0593 - val_loss: 15.4274 - val_mean_absolute_error: 2.6593\n", | |
"Epoch 11/20\n", | |
"163/163 [==============================] - 37s 227ms/step - loss: 11.2502 - mean_absolute_error: 2.2828 - val_loss: 5.5709 - val_mean_absolute_error: 1.3830\n", | |
"Epoch 12/20\n", | |
"163/163 [==============================] - 36s 224ms/step - loss: 8.2588 - mean_absolute_error: 1.8053 - val_loss: 3.8581 - val_mean_absolute_error: 1.1801\n", | |
"Epoch 13/20\n", | |
"163/163 [==============================] - 36s 224ms/step - loss: 5.0152 - mean_absolute_error: 1.2891 - val_loss: 13.0625 - val_mean_absolute_error: 2.6455\n", | |
"Epoch 14/20\n", | |
"163/163 [==============================] - 36s 224ms/step - loss: 5.0283 - mean_absolute_error: 1.4202 - val_loss: 2.6965 - val_mean_absolute_error: 0.9099\n", | |
"Epoch 15/20\n", | |
"163/163 [==============================] - 36s 224ms/step - loss: 3.2636 - mean_absolute_error: 1.0700 - val_loss: 2.5475 - val_mean_absolute_error: 0.9087\n", | |
"Epoch 16/20\n", | |
"163/163 [==============================] - 37s 226ms/step - loss: 2.5395 - mean_absolute_error: 0.9430 - val_loss: 1.8564 - val_mean_absolute_error: 0.7822\n", | |
"Epoch 17/20\n", | |
"163/163 [==============================] - 38s 231ms/step - loss: 2.4931 - mean_absolute_error: 0.9311 - val_loss: 1.3757 - val_mean_absolute_error: 0.6267\n", | |
"Epoch 18/20\n", | |
"163/163 [==============================] - 37s 230ms/step - loss: 2.2431 - mean_absolute_error: 0.8908 - val_loss: 1.2690 - val_mean_absolute_error: 0.6079\n", | |
"Epoch 19/20\n", | |
"163/163 [==============================] - 37s 229ms/step - loss: 3.3244 - mean_absolute_error: 1.1000 - val_loss: 1.8255 - val_mean_absolute_error: 0.8506\n", | |
"Epoch 20/20\n", | |
"163/163 [==============================] - 36s 223ms/step - loss: 1.4627 - mean_absolute_error: 0.7011 - val_loss: 1.1646 - val_mean_absolute_error: 0.5915\n" | |
] | |
} | |
], | |
"source": [ | |
"if buckets > 0:\n", | |
" # Create Sequence objects\n", | |
" train_generator = BucketedSequence(buckets, batch_size, len_train, x_train, y_train)\n", | |
" val_generator = BucketedSequence(buckets, batch_size, len_val, x_val, y_val)\n", | |
" \n", | |
" model.fit_generator(train_generator, \n", | |
" epochs=epochs, \n", | |
" validation_data=val_generator, \n", | |
" shuffle=True, \n", | |
" verbose=True)\n", | |
" \n", | |
"else:\n", | |
" model.fit(x=x_train, y=y_train, \n", | |
" epochs=epochs, \n", | |
" validation_data=(x_val, y_val), \n", | |
" batch_size=batch_size, \n", | |
" verbose=True, \n", | |
" shuffle=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment