Created
October 27, 2020 14:11
-
-
Save ita9naiwa/b328c43508193611a83c07ae0553a9f3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import implicit\n", | |
"import pickle" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from implicit.cml import CollaborativeMetricLearning\n", | |
"from implicit.als import AlternatingLeastSquares\n", | |
"from implicit.lmf import LogisticMatrixFactorization\n", | |
"from implicit.evaluation import *\n", | |
"from implicit.datasets.sketchfab import get_sketchfab\n", | |
"from implicit.datasets.movielens import get_movielens" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#_, _, mat = get_sketchfab()\n", | |
"_, mat = get_movielens(variant='1m')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"seed=1541" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mat.data[:] = 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tr, te = train_test_split(mat, 0.8)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"m2 = CollaborativeMetricLearning(factors=64, \n", | |
" threshold=1.0,\n", | |
" learning_rate=0.1, \n", | |
" iterations=15, \n", | |
" num_threads=8, \n", | |
" regularization=0.00,\n", | |
" neg_sampling=100,\n", | |
" random_state=seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 15/15 [00:40<00:00, 2.79s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0c6385a9bd6f426f8012661292f6b607", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"{'precision': 0.2958226060958392, 'map': 0.20293504045410232, 'ndcg': 0.28618070893351455, 'auc': 0.5181144038465575}\n" | |
] | |
} | |
], | |
"source": [ | |
"m2.fit(tr.T, True)\n", | |
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m2 = CollaborativeMetricLearning(factors=64, \n", | |
" threshold=1.0,\n", | |
" learning_rate=0.1, \n", | |
" iterations=15, \n", | |
" num_threads=8, \n", | |
" regularization=0.01,\n", | |
" neg_sampling=100,\n", | |
" random_state=seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 15/15 [00:43<00:00, 3.08s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "4b94fa55352f4e7bba4a0ec9c9bff749", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"{'precision': 0.33095257868947764, 'map': 0.23899023594696514, 'ndcg': 0.32266404435662915, 'auc': 0.5204181639238991}\n" | |
] | |
} | |
], | |
"source": [ | |
"m2.fit(tr.T, True)\n", | |
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m2 = CollaborativeMetricLearning(factors=64, \n", | |
" threshold=1.0,\n", | |
" learning_rate=0.3, \n", | |
" iterations=15, \n", | |
" num_threads=8, \n", | |
" regularization=0.03,\n", | |
" neg_sampling=100,\n", | |
" random_state=seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 15/15 [01:03<00:00, 4.25s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "74a1d29d18034cdfbcd2d1292e5de7a7", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"{'precision': 0.37646374885806827, 'map': 0.28215497521921507, 'ndcg': 0.3662162875887344, 'auc': 0.5236010609909335}\n" | |
] | |
} | |
], | |
"source": [ | |
"m2.fit(tr.T, True)\n", | |
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m2 = CollaborativeMetricLearning(factors=64, \n", | |
" threshold=1.0,\n", | |
" learning_rate=0.1, \n", | |
" iterations=15, \n", | |
" num_threads=8, \n", | |
" regularization=0.05,\n", | |
" neg_sampling=100,\n", | |
" random_state=seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 15/15 [00:54<00:00, 3.70s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "641d0792eeb0442b9dc36aeb7f2b2601", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"{'precision': 0.37380616227888047, 'map': 0.2818614182234088, 'ndcg': 0.36362742743417237, 'auc': 0.5222198795750684}\n" | |
] | |
} | |
], | |
"source": [ | |
"m2.fit(tr.T, True)\n", | |
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m2 = CollaborativeMetricLearning(factors=64, \n", | |
" threshold=1.0,\n", | |
" learning_rate=0.1, \n", | |
" iterations=15, \n", | |
" num_threads=8, \n", | |
" regularization=0.1,\n", | |
" neg_sampling=100,\n", | |
" random_state=seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 15/15 [00:55<00:00, 3.85s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "acc87622b9974a4493f5257591f6ec11", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"{'precision': 0.3812806245328461, 'map': 0.28596698013301086, 'ndcg': 0.36791157480005765, 'auc': 0.5231909268136861}\n" | |
] | |
} | |
], | |
"source": [ | |
"m2.fit(tr.T, True)\n", | |
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment