Last active
July 17, 2016 04:47
-
-
Save kwatch/e31c0e290979ee1e7429ae0e14a16af6 to your computer and use it in GitHub Desktop.
パーセプトロンのサンプルコード
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
## | |
## 参考: 「Python機械学習プログラミング」 | |
## | |
## 注: 事前にCSVファイルをダウンロードしておくこと | |
## $ wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data | |
## | |
class Perceptron(object):
    """Binary perceptron classifier whose predictions are 1 or -1.

    Reference: "Python Machine Learning" (Python機械学習プログラミング).

    Parameters:
        eta: learning rate; must satisfy 0.0 < eta <= 1.0.
        n_iter: number of passes (epochs) over the training data.
        verbose: when true (the default), ``fit`` prints per-epoch and
            per-misclassification progress to stdout; pass False to train
            silently.
    """

    # Class-level default so instances built without __init__ keep the
    # original (printing) behavior.
    verbose = True

    def __init__(self, eta=0.01, n_iter=10, verbose=True):
        assert 0.0 < eta <= 1.0
        self.eta = eta          # learning rate (0.0 < eta <= 1.0)
        self.n_iter = n_iter    # number of training epochs
        self.verbose = verbose  # print training progress to stdout

    def fit(self, input__, actual_):
        """Learn weights from the samples ``input__`` and labels ``actual_``.

        Args:
            input__: non-empty sequence of samples; each sample is a sequence
                of numeric features, all of the same length.
            actual_: expected label (1 or -1) for each sample; must have the
                same length as ``input__``.

        Returns:
            ``(weight_, missed_)``. ``weight_`` is the list
            ``[bias, w1, ..., wn]``. ``missed_`` holds the number of
            misclassified samples in each epoch. Both are also stored as
            ``self.weight_`` and ``self.missed_``.
        """
        assert len(input__) > 0
        assert len(input__) == len(actual_)
        verbose = self.verbose
        # Initialise the weights: weight_[0] is the bias, then one per feature.
        n_features = len(input__[0])          # number of features (e.g. 2)
        weight_ = [0.0] * (1 + n_features)    # e.g. [0.0, 0.0, 0.0]
        missed_ = []                          # misclassification count per epoch
        eta = self.eta                        # learning rate (self.eta)
        for n in range(self.n_iter):
            if verbose:
                print("")
                print("========== loop #%s ==========" % (n+1))
                print("* weight_=%r" % weight_)
            missed = 0  # misclassifications in this epoch
            for input_, actual in zip(input__, actual_):
                # Weighted sum of the inputs, then threshold it to 1 or -1.
                z = self._net_input(input_, weight_)
                guess = self._predict(z)
                # Adjust the weights only when the guess was wrong.
                if actual != guess:
                    missed += 1
                    # actual - guess is 2 or -2 here, so update is +/- 2*eta.
                    update = eta * (actual - guess)
                    weight_[0] += update * 1.0  # the bias input x0 is always 1
                    for i, x in enumerate(input_, 1):
                        weight_[i] += update * x
                    if verbose:
                        print("* input_=%r, actual=%r, guess=%r, weight_=%r" % (input_, actual, guess, weight_))
            if verbose:
                print("* missed=%r" % missed)
            missed_.append(missed)
        self.weight_ = weight_  # named self.w_ in the book
        self.missed_ = missed_  # named self.errors_ in the book
        return weight_, missed_

    def _net_input(self, input_, weight_):
        """Return bias + sum(x_i * w_i) for one sample."""
        assert len(input_) + 1 == len(weight_)
        z = 1.0 * weight_[0] + \
            sum(x * w for x, w in zip(input_, weight_[1:]))
        return z

    def _predict(self, z):
        """Threshold the net input: 1 when z >= 0.0, otherwise -1."""
        if z >= 0.0:
            return 1   # MainApp maps 'Iris-versicolor' to 1
        else:
            return -1  # MainApp maps 'Iris-setosa' to -1
### | |
### パーセプトロンを実行するサンプルコード | |
### | |
import sys | |
import csv | |
def _debug(name, value): | |
sys.stderr.write("\033[0;31m*** debug: %s=%r\033[0m\n" % (name, value)) | |
class BaseApp(object):
    """Base class for the sample command-line apps.

    Subclasses implement ``run()``. This class provides CSV loading and a
    simple line-plot helper.
    """

    def _load_csv_file(self, filename):
        """Read ``filename`` as CSV and return every row as a list of strings."""
        # The csv module requires files opened with newline='' so that
        # newlines embedded in quoted fields are passed through untranslated.
        with open(filename, newline='') as f:
            return list(csv.reader(f))

    def _build_plot(self, value_, xlabel, ylabel, marker='o'):
        """Plot ``value_`` against x = 1..len(value_), label the axes and show it.

        Returns the ``matplotlib.pyplot`` module. ``plt.show()`` is called
        here, so the figure has already been displayed when this returns.
        """
        # Imported lazily, presumably so matplotlib is needed only when plotting.
        import matplotlib.pyplot as plt
        plt.plot(range(1, len(value_) + 1), value_, marker=marker)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.tight_layout()
        plt.show()
        return plt

    def run(self, *args, **kwargs):
        """Run the app. Subclasses must override this."""
        raise NotImplementedError("%s.run(): not implemented yet." % self.__class__.__name__)
class MainApp(BaseApp):
    """Train a Perceptron on the Iris data set and report the results."""

    def run(self, filename='iris.data', show_graph=True):
        """Load the Iris CSV, train a Perceptron, then print and optionally plot results.

        Note: download the CSV file beforehand, e.g.
            $ wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

        Args:
            filename: path to the Iris CSV file.
            show_graph: when true, plot the misclassifications per epoch.

        Raises:
            ValueError: if any of the first 100 rows has a class other than
                'Iris-setosa' or 'Iris-versicolor'.
        """
        table__ = self._load_csv_file(filename)
        # Only the first 100 rows are used; this assumes they are the
        # Iris-setosa and Iris-versicolor samples (fn() below raises otherwise).
        rows = table__[0:100]
        # Keep only sepal length (column 0) and petal length (column 2).
        input__ = [[float(r_[0]), float(r_[2])] for r_ in rows]

        def fn(name):
            # Map the class name in column 4 to the perceptron's label.
            if name == 'Iris-setosa':
                return -1
            elif name == 'Iris-versicolor':
                return 1
            raise ValueError("%r: unexpected value" % name)

        actual_ = [fn(row_[4]) for row_ in rows]

        perceptron = Perceptron(eta=0.1, n_iter=10)
        weight_, missed_ = perceptron.fit(input__, actual_)
        print("")
        print("========== result ==========")
        print("* weight_=%r" % weight_)
        print("* missed_=%r" % missed_)

        if show_graph:
            # _build_plot() calls plt.show() itself, so no second show()
            # is needed here.
            self._build_plot(missed_, 'Epochs', 'Number of misclassifications')

    @classmethod
    def main(cls):
        """Script entry point: ``python <script> [FILENAME]``.

        The optional first command-line argument is the path to the Iris
        CSV file; it defaults to 'iris.data'. Any further arguments are
        ignored.
        """
        argv = sys.argv[1:]
        filename = argv[0] if argv else 'iris.data'
        cls().run(filename, show_graph=True)
# Run the sample only when this file is executed as a script, not on import.
if __name__ == "__main__":
    MainApp.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment