Skip to content

Instantly share code, notes, and snippets.

@kwatch
Last active July 17, 2016 04:47
Show Gist options
  • Save kwatch/e31c0e290979ee1e7429ae0e14a16af6 to your computer and use it in GitHub Desktop.
Save kwatch/e31c0e290979ee1e7429ae0e14a16af6 to your computer and use it in GitHub Desktop.
パーセプトロンのサンプルコード
# -*- coding: utf-8 -*-
##
## 参考: 「Python機械学習プログラミング」
##
## 注: 事前にCSVファイルをダウンロードしておくこと
## $ wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
##
class Perceptron(object):
    """Single-layer perceptron for two classes labelled 1 and -1.

    Implemented with plain Python lists (no NumPy).  After ``fit()`` the
    learned weights are available as ``self.weight_`` (bias term first)
    and the per-epoch misclassification counts as ``self.missed_``.
    """

    def __init__(self, eta=0.01, n_iter=10, verbose=True):
        """Store the hyper-parameters.

        Args:
            eta: learning rate; must satisfy ``0.0 < eta <= 1.0``.
            n_iter: number of passes (epochs) over the training data.
            verbose: when true (the default, which keeps the historical
                output) ``fit()`` prints a training trace to stdout.

        Raises:
            AssertionError: if ``eta`` is outside ``(0.0, 1.0]``.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        if not (0.0 < eta <= 1.0):
            raise AssertionError(
                "eta must satisfy 0.0 < eta <= 1.0: %r" % (eta,))
        self.eta = eta
        self.n_iter = n_iter
        self.verbose = verbose

    def fit(self, input__, actual_):
        """Train on the given samples and return ``(weight_, missed_)``.

        Args:
            input__: non-empty sequence of samples; each sample is a
                sequence of numeric feature values.
            actual_: expected labels (1 or -1), one per sample.

        Returns:
            A tuple ``(weight_, missed_)``.  ``weight_`` is the list
            ``[bias, w1, ..., wn]``; ``missed_`` holds the number of
            misclassified samples for each epoch.  The same two lists
            are also stored as ``self.weight_`` and ``self.missed_``.

        Raises:
            AssertionError: if ``input__`` is empty, if ``input__`` and
                ``actual_`` differ in length, or if a sample does not
                have as many features as the first sample.
        """
        if not len(input__) > 0:
            raise AssertionError("input__ must not be empty")
        if len(input__) != len(actual_):
            raise AssertionError(
                "input__ and actual_ differ in length: %r != %r"
                % (len(input__), len(actual_)))
        # One weight per feature plus the bias at index 0, all starting at 0.0.
        n_features = len(input__[0])
        weight_ = [0.0] * (1 + n_features)
        missed_ = []        # misclassification count of every epoch
        eta = self.eta
        verbose = self.verbose
        for n in range(self.n_iter):
            if verbose:
                print("")
                print("========== loop #%s ==========" % (n+1))
                print("* weight_=%r" % weight_)
            missed = 0      # misclassifications in this epoch
            for input_, actual in zip(input__, actual_):
                guess = self._predict(self._net_input(input_, weight_))
                # Weights are only adjusted when the guess was wrong.
                if actual != guess:
                    missed += 1
                    # (actual - guess) is 2 or -2, so update is +/- 2*eta.
                    update = eta * (actual - guess)
                    weight_[0] += update    # the bias input x0 is always 1
                    for i, x in enumerate(input_, 1):
                        weight_[i] += update * x
                    if verbose:
                        print("* input_=%r, actual=%r, guess=%r, weight_=%r" % (input_, actual, guess, weight_))
            if verbose:
                print("* missed=%r" % missed)
            missed_.append(missed)
        self.weight_ = weight_
        self.missed_ = missed_
        return weight_, missed_

    def _net_input(self, input_, weight_):
        """Return ``weight_[0] + sum(x_i * w_i)`` for a single sample.

        Raises:
            AssertionError: if ``weight_`` does not hold exactly one
                more element than ``input_``.
        """
        if len(input_) + 1 != len(weight_):
            raise AssertionError(
                "expected %r feature(s) but got %r"
                % (len(weight_) - 1, len(input_)))
        return 1.0 * weight_[0] + \
            sum(x * w for x, w in zip(input_, weight_[1:]))

    def _predict(self, z):
        """Turn net input ``z`` into a label: 1 if ``z >= 0.0``, else -1.

        In the accompanying Iris sample 1 stands for 'Iris-versicolor'
        and -1 for 'Iris-setosa'.
        """
        return 1 if z >= 0.0 else -1
###
### パーセプトロンを実行するサンプルコード
###
import sys
import csv
def _debug(name, value):
    """Write ``name=value`` (value formatted with ``repr``) to stderr.

    The line is wrapped in the ANSI escape sequences ``ESC[0;31m`` /
    ``ESC[0m`` so that terminals supporting them show it in red.
    """
    sys.stderr.write("\033[0;31m*** debug: %s=%r\033[0m\n" % (name, value))
class BaseApp(object):
    """Base class for the sample applications.

    Provides CSV loading and a simple line-plot helper; subclasses must
    implement ``run()``.
    """

    def _load_csv_file(self, filename):
        """Read ``filename`` as CSV and return its rows.

        Returns:
            A list with one list of strings per line of the file.  Blank
            lines are returned as empty lists, just as ``csv.reader``
            yields them.
        """
        # newline='' lets the csv module do its own newline handling, as
        # recommended by the csv module documentation.
        with open(filename, newline='') as f:
            return list(csv.reader(f))

    def _build_plot(self, value_, xlabel, ylabel, marker='o'):
        """Plot ``value_`` against 1..len(value_), show it, return pyplot.

        Args:
            value_: sequence of y values; x runs from 1 to ``len(value_)``.
            xlabel: label of the x axis.
            ylabel: label of the y axis.
            marker: matplotlib marker style used for the data points.

        Returns:
            The ``matplotlib.pyplot`` module, after ``show()`` was called.
        """
        # Imported here rather than at module level, so matplotlib is only
        # required when a plot is actually requested.
        import matplotlib.pyplot as plt
        plt.plot(range(1, len(value_) + 1), value_, marker=marker)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.tight_layout()
        plt.show()
        return plt

    def run(self, *args, **kwargs):
        """Application entry point; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("%s.run(): not implemented yet." % self.__class__.__name__)
class MainApp(BaseApp):
    """Train a ``Perceptron`` on the Iris data set and report the result."""

    # Maps the class-name column of iris.data to the perceptron's labels.
    _LABELS = {'Iris-setosa': -1, 'Iris-versicolor': 1}

    def _label_of(self, name):
        """Return the numeric label (-1 or 1) for an Iris class name.

        Raises:
            ValueError: if ``name`` is not one of the two known classes.
        """
        label = self._LABELS.get(name)
        if label is None:
            raise ValueError("%r: unexpected value" % name)
        return label

    def run(self, filename='iris.data', show_graph=True):
        """Load the CSV file, train the perceptron and print the result.

        The CSV file has to be downloaded beforehand:
        $ wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

        Args:
            filename: path of the Iris CSV file.
            show_graph: when true, plot the number of misclassifications
                per epoch after training.

        Raises:
            ValueError: if one of the rows used carries a class name
                other than 'Iris-setosa' or 'Iris-versicolor'.
        """
        table__ = self._load_csv_file(filename)
        # Only the first 100 rows are used; they are expected to be the
        # Iris-setosa and Iris-versicolor samples.  Blank lines (returned
        # by the CSV reader as empty rows) are skipped.
        rows = [row_ for row_ in table__[0:100] if row_]
        # Features: sepal length (column 0) and petal length (column 2).
        input__ = [[float(row_[0]), float(row_[2])] for row_ in rows]
        actual_ = [self._label_of(row_[4]) for row_ in rows]
        #
        perceptron = Perceptron(eta=0.1, n_iter=10)
        weight_, missed_ = perceptron.fit(input__, actual_)
        print("")
        print("========== result ==========")
        print("* weight_=%r" % weight_)
        print("* missed_=%r" % missed_)
        #
        if show_graph:
            # _build_plot() already displays the figure, so no additional
            # plt.show() call is needed here.
            self._build_plot(missed_, 'Epochs', 'Number of misclassifications')

    @classmethod
    def main(cls):
        """Command-line entry point: ``python <script> [CSV_FILE]``.

        Uses the first command-line argument as the CSV file name and
        falls back to 'iris.data' when none is given.
        """
        argv = sys.argv[1:]
        filename = argv[0] if argv else 'iris.data'
        cls().run(filename, show_graph=True)
# Script entry point: run the sample only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    MainApp.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment