Created January 6, 2018 23:24
-
-
Save gurimusan/de9704dc524c7977fca9bfbfbe7eb43e to your computer and use it in GitHub Desktop.
ミニバッチ勾配降下法
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import contextlib
import random
import re
import urllib2

import numpy
def _batch_indexes(indexes, batch_size):
    """Split ``indexes`` into consecutive chunks of at most ``batch_size`` items.

    :param indexes: sliceable sequence supporting ``len()`` (e.g. a list of
                    row numbers)
    :param int batch_size: maximum chunk length; must be at least 1
    :return: iterator over slices of ``indexes`` in their original order; the
             last slice is shorter when ``len(indexes)`` is not a multiple of
             ``batch_size``
    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        # A negative step makes range() empty, so the caller would silently
        # receive no batches at all; zero only fails with a message about
        # range() arguments.  Reject both up front with a clear error.
        raise ValueError("batch_size must be a positive integer, got %r"
                         % (batch_size,))
    return (indexes[start:start + batch_size]
            for start in range(0, len(indexes), batch_size))
def mini_batch_gradient_descent(X, y, initial_theta, alpha, num_iters=1500,
                                batch_size=100):
    """Run mini-batch gradient descent on a data set and return the fitted theta.

    A bias column of ones is prepended to ``X``.  On every pass the rows are
    visited in a freshly shuffled order, ``batch_size`` rows at a time, and
    ``theta`` takes one step of size ``alpha`` against the gradient of half
    the mean squared error of the current batch.

    :param numpy.ndarray X: explanatory variables, shape ``(m, n)``
    :param numpy.ndarray y: observed results, shape ``(m, 1)`` or ``(m,)``
    :param numpy.ndarray initial_theta: initial parameters, ``n + 1`` values
                                        with the bias term first
    :param float alpha: learning rate
    :param int num_iters: number of passes over the whole data set
    :param int batch_size: number of rows per mini-batch
    :return: theta after the last update, as an ``(n + 1, 1)`` column vector
    :raises ValueError: if ``y`` does not have one value per row of ``X`` or
                        ``initial_theta`` does not have ``n + 1`` values
    """
    # numpy.insert returns a new array, so the caller's X is left untouched
    # without an explicit copy.
    xp = numpy.insert(X, 0, 1, axis=1)
    n_samples, n_params = xp.shape

    # Work with column vectors throughout so 1-D inputs behave like (m, 1).
    targets = numpy.asarray(y, dtype=float).reshape(-1, 1)
    theta = numpy.array(initial_theta, dtype=float).reshape(-1, 1)
    if targets.shape[0] != n_samples:
        raise ValueError("y has %d values but X has %d rows"
                         % (targets.shape[0], n_samples))
    if theta.shape[0] != n_params:
        raise ValueError("initial_theta has %d values, expected %d"
                         % (theta.shape[0], n_params))

    for _ in range(num_iters):
        # Start each pass from a new random ordering of the rows.
        indexes = list(range(n_samples))
        random.shuffle(indexes)
        for cur_indexes in _batch_indexes(indexes, batch_size):
            batch_x = xp[cur_indexes, :]
            batch_y = targets[cur_indexes, :]
            residual = batch_x.dot(theta) - batch_y
            # Mean over the batch of residual_i * x_i, as a column vector.
            grad = (1.0 / batch_x.shape[0]) * batch_x.T.dot(residual)
            theta = theta - alpha * grad
    return theta
if __name__ == '__main__':
    # Demo: download a whitespace-separated numeric table, fit a linear model
    # to it with mini-batch gradient descent and print the resulting theta.
    datasrc = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'

    # Close the HTTP response deterministically instead of leaking it.
    with contextlib.closing(urllib2.urlopen(datasrc)) as response:
        raw = response.read()

    # Assumes one record per line with whitespace-separated numbers.  Blank
    # lines are skipped so float() is never handed an empty string.
    data = numpy.array([[float(v) for v in row.split()]
                        for row in raw.splitlines() if row.strip()])

    X = data[:, :-1]  # every column except the last: explanatory variables
    Y = data[:, -1:]  # last column, kept 2-D: the value to predict

    # One parameter per explanatory variable plus one for the bias column.
    initial_theta = numpy.zeros((X.shape[1] + 1, 1))
    alpha = 0.01
    num_iters = 1500

    # Normalize each column to zero mean and unit standard deviation.
    X = (X - numpy.mean(X, axis=0)) / numpy.std(X, axis=0)

    print(mini_batch_gradient_descent(X, Y, initial_theta, alpha, num_iters))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.