Created
March 22, 2017 01:30
-
-
Save yamaguchiyuto/155675e2f75a2d82e1232fb539608c9e to your computer and use it in GitHub Desktop.
NCTM experiment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
 "cells": [
{
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {
  "collapsed": true
 },
 "outputs": [],
 "source": [
  "# The two models compared in this notebook come from project modules\n",
  "# (nctm, ctm) that are not included here.\n",
  "from nctm import NCTM\n",
  "from ctm import CTM"
 ]
},
{
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {
  "collapsed": true
 },
 "outputs": [],
 "source": [
  "# Generate a toy corpus\n",
  "vocab = {'computer':0, 'banana':1, 'ipad':2, 'orange':3, 'apple':4}\n",
  "categories = {'FOODS':0, 'TECH':1, 'TOREAD':2}\n",
  "\n",
  "Vw = len(vocab)  # vocabulary size\n",
  "Vx = len(categories)  # number of distinct categories\n",
  "\n",
  "# One (words, categories) pair per document. 'TOREAD' is attached to both\n",
  "# tech and food documents, so it is not specific to either group.\n",
  "docs = [\n",
  "    (['computer', 'ipad', 'apple'], ['TECH', 'TOREAD']),\n",
  "    (['ipad', 'ipad'], ['TECH', 'TOREAD']),\n",
  "    (['ipad', 'ipad', 'apple', 'apple'], ['TOREAD']),\n",
  "    (['banana', 'orange', 'apple'], ['FOODS', 'TOREAD']),\n",
  "    (['banana', 'orange'], ['FOODS']),\n",
  "]\n",
  "\n",
  "# Build real lists rather than map() results (lazy, single-use iterators on\n",
  "# Python 3) so W and X can be indexed and iterated repeatedly.\n",
  "W = [[vocab[w] for w in words] for words, cats in docs]  # word ids per document\n",
  "X = [[categories[c] for c in cats] for words, cats in docs]  # category ids per document"
 ]
},
{
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {
  "collapsed": true
 },
 "outputs": [],
 "source": [
  "# Model definition: hyperparameters shared by both models\n",
  "K = 2  # number of topics\n",
  "alpha = 0.1\n",
  "beta = 0.1\n",
  "gamma = 0.1\n",
  "eta = 1.0  # passed to NCTM only\n",
  "max_iter = 100  # passed to both constructors\n",
  "\n",
  "# NOTE(review): no random seed is set here. If CTM/NCTM sample via numpy.random\n",
  "# or random, seed it in this cell so 'Restart & Run All' reproduces the outputs\n",
  "# recorded below - TODO confirm which RNG the models use.\n",
  "ctm = CTM(K=K, alpha=alpha, beta=beta, gamma=gamma, max_iter=max_iter)\n",
  "nctm = NCTM(K=K, alpha=alpha, beta=beta, gamma=gamma, eta=eta, max_iter=max_iter)"
 ]
},
{
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {
  "collapsed": false
 },
 "outputs": [],
 "source": [
  "# Fit NCTM. fit() returns an NCTM instance; the trailing ';' suppresses\n",
  "# its uninformative repr (which includes a memory address).\n",
  "nctm.fit(W, X, Vw, Vx);"
 ]
},
{
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {
  "collapsed": false
 },
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "[array([0, 0, 0]), array([0, 0]), array([0, 0, 0, 0]), array([1, 1, 1]), array([1, 1])]\n",
    "[array([0, 0]), array([0, 0]), array([0]), array([1, 1]), array([1])]\n",
    "[array([1, 0]), array([1, 0]), array([0]), array([1, 0]), array([1])]\n"
   ]
  }
 ],
 "source": [
  "# Topic assignments are clean, and the side information 'TOREAD'\n",
  "# is identified as noise (every R entry for 'TOREAD' is 0 in the output below).\n",
  "# print(x) with a single argument prints identically on Python 2 and 3.\n",
  "print(nctm.Z)\n",
  "print(nctm.Y)\n",
  "print(nctm.R)"
 ]
},
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {
  "collapsed": false
 },
 "outputs": [],
 "source": [
  "# Fit CTM. fit() returns a CTM instance; the trailing ';' suppresses\n",
  "# its uninformative repr (which includes a memory address).\n",
  "ctm.fit(W, X, Vw, Vx);"
 ]
},
{
 "cell_type": "code",
 "execution_count": 7,
 "metadata": {
  "collapsed": false
 },
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "[array([1, 0, 0]), array([0, 0]), array([0, 0, 0, 0]), array([1, 1, 0]), array([1, 1])]\n",
    "[array([0, 0]), array([0, 0]), array([0]), array([1, 0]), array([1])]\n"
   ]
  }
 ],
 "source": [
  "# Because topic 0 is assigned to the side information 'TOREAD',\n",
  "# the word 'apple' in the 4th document also ends up in topic 0.\n",
  "# print(x) with a single argument prints identically on Python 2 and 3.\n",
  "print(ctm.Z)\n",
  "print(ctm.Y)"
 ]
}
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment