Skip to content

Instantly share code, notes, and snippets.

@tushuhei
Created September 4, 2017 02:37
Show Gist options
  • Save tushuhei/e96aa25969d845cf63c000e13c872bc3 to your computer and use it in GitHub Desktop.
# coding: utf-8
#
# Simulation program for unplugged neural network model.
# Reference:
# https://www.1101.com/morikawa/2001-03-12.html
# https://www.1101.com/morikawa/2001-04-02.html
import itertools
import random
import numpy as np
random.seed(1)
np.random.seed(1)  # NumPy's RNG is not used by this script; seeded for reproducibility anyway.

# Unit prices of the three items; an order ("choice") is a vector of item counts.
prices = np.array([310, 220, 70])
max_choice = 3  # each item may be ordered 0 .. max_choice - 1 times
max_price = 600  # the neuron should fire exactly when an order costs more than this

# Enumerate every possible order, then shuffle with the fixed seed so the
# train/test split is reproducible. The [::-1] is kept on purpose: the shuffle
# result depends on input order, so dropping it would change the split.
data = [np.array(order)
        for order in itertools.product(range(max_choice), repeat=len(prices))][::-1]
random.shuffle(data)
train_data = data[:5]
test_data = data[5:]

iters = 3  # training epochs over train_data
nodes = np.array([5, 5, 5])  # per-item weights; clipped to [0, 10] during training
thres = 6  # firing threshold on the weighted sum
learn_rate = iters  # step size; decremented after every epoch (3, 2, 1 here)


def classify(choice, weights, threshold):
    """Decide whether the neuron fires for *choice* and whether it should.

    Args:
        choice: item counts, same length as ``prices``.
        weights: per-item weights.
        threshold: the neuron fires when ``weights . choice`` is strictly greater.

    Returns:
        ``(fired, expensive)`` as plain bools, where ``expensive`` means the
        order's total price strictly exceeds ``max_price``. The decision is
        correct exactly when the two values are equal.
    """
    fired = bool(np.dot(weights, choice) > threshold)
    expensive = bool(np.dot(prices, choice) > max_price)
    return fired, expensive


def report_choice(choice, weights, threshold):
    """Print the state line and verdict for one order; return ``classify``'s result."""
    print('CHOICE: %s,\tSUM_PRICE: %.2f,\tSUM_WEIGHT:%d\tTHRES:%d\tNodes:%s' % (
        choice, np.dot(prices, choice), np.dot(weights, choice), threshold, weights))
    fired, expensive = classify(choice, weights, threshold)
    if fired == expensive:
        verdict = 'ok'
    elif fired:
        verdict = 'ng. too sensitive.'
    else:
        verdict = 'ng. too insensitive.'
    # Same output as the Python 2 `print 'FIRE',` followed by `print verdict`.
    print('FIRE' if fired else 'STAY', verdict)
    return fired, expensive


def evaluate(dataset, weights, threshold):
    """Report every order in *dataset* without changing the model."""
    for choice in dataset:
        report_choice(choice, weights, threshold)


evaluate(test_data, nodes, thres)

print('>>> TRAINING START <<<')
for _ in range(iters):
    for choice in train_data:
        fired, expensive = report_choice(choice, nodes, thres)
        if fired != expensive:
            # Fired on a cheap order -> raise the threshold and lower the weights
            # of the items in it; stayed on an expensive order -> the opposite.
            step = learn_rate if fired else -learn_rate
            thres += step
            nodes[choice > 0] -= step
        # Keep weights in [0, 10] before the next decision (np.clip returns a copy).
        nodes = np.clip(nodes, 0, 10)
    learn_rate -= 1
print('>>> TRAINING DONE <<<')
evaluate(test_data, nodes, thres)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment