Reinforcement learning example for a minimax-method Reversi engine.
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 強化学習によるリバーシの思考ルーチン作成 (Part1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"このチュートリアルでは、強化学習によるリバーシ(オセロゲーム)の思考ルーチン作成を行います。\n", | |
"\n", | |
"リバーシの思考ルーチンでは、盤面の「評価関数」を用いたミニマックス法のアルゴリズムがよく知られています。ここでは、盤面の評価関数を強化学習で作成します。\n", | |
"\n", | |
"※ 以下の解説では、「自分が打つべき手を決める方法」を説明していますが、実際には、これがコンピューターの思考ルーチンに置き換わります。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ミニマックス法について" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 評価関数\n", | |
"\n", | |
"自分の手番の盤面について、その盤面の評価値を計算する「評価関数」が存在するものとします。盤面の評価値は「その盤面が自分にとってどの程度有利な状況か」を表します。\n", | |
"\n", | |
"### 相手の手の予測\n", | |
"\n", | |
"自分の手番が来たら、自分が打てる手のそれぞれについて、その手を打った後、次に相手がどのような手を打つかを評価関数を用いて予測します。\n", | |
"\n", | |
"具体的には、相手は自分が不利になる手を打つはずですので、相手は、可能な手の中から、その手を打った後の評価値がもっとも小さくなるものを選択すると予測します。\n", | |
"\n", | |
"### 自分の手の決定\n", | |
"\n", | |
"前述の方法で相手の手を予測することで、自分が打てるそれぞれの手について、その手を打った場合の次の手番における盤面が決まります。このようにして決まる、次の盤面の評価値が最大になる手を選択します。\n", | |
"\n", | |
"このアルゴリズムは、「相手が評価値を最小とする手を選択した際に、その中でも評価値が最大となるものを選択する」という意味で「ミニマックス法」と呼ばれています。ミニマックス法では、評価関数をどのように設定するかによって、実際に選択される手が大きく変わります。" | |
] | |
}, | |
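{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The next cell is a small toy sketch added for illustration of the one-ply minimax selection described above. The function toy_value and the candidate moves are hypothetical placeholders with made-up numbers, not real Reversi positions or the evaluation function used later." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Toy sketch of one-ply minimax selection (illustration only; all values are made up).\n", | |
"def toy_value(board_id):\n", | |
"    # Hypothetical evaluation values for board identifiers 'a'..'f'\n", | |
"    return {'a': 3, 'b': -1, 'c': 5, 'd': 2, 'e': 0, 'f': 4}[board_id]\n", | |
"\n", | |
"# For each of my candidate moves, the boards the opponent can reach with their reply\n", | |
"replies = {'move1': ['a', 'b'], 'move2': ['c', 'd'], 'move3': ['e', 'f']}\n", | |
"\n", | |
"# The opponent is assumed to pick the reply that minimizes my value ...\n", | |
"worst_case = dict((m, min(toy_value(r) for r in rs)) for m, rs in replies.items())\n", | |
"# ... and I pick the move whose worst case is largest.\n", | |
"best_move = max(worst_case, key=worst_case.get)\n", | |
"print sorted(worst_case.items())   # [('move1', -1), ('move2', 2), ('move3', 0)]\n", | |
"print best_move                    # move2" | |
] | |
}, | |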
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 評価関数の求め方" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 大雑把な定義方法\n", | |
"\n", | |
"強化学習では、その盤面から状態を進めていった際に最終的に得られる「Reward」の合計をその盤面の評価値とします。Rewardの定義には任意性がありますが、ここでは、最終的な勝ち点(最終状態での「自分のコマ数 - 相手のコマ数」)を得られるRewardとします。\n", | |
"\n", | |
"もう少し正確には、次のようにして、あらゆる盤面の評価値を帰納的に定義します。\n", | |
"\n", | |
"ここでは、自分が先行と仮定します。また、γを0.8程度の1より少し小さい値とします。\n", | |
"\n", | |
"まず、残り2手の状態で、ミニマックス法に従って自分が打った後、相手が打って勝敗が決定する盤面を考えます。この盤面については、最終的な勝ち点をその盤面の評価値とします。(途中でパスが入るなどして、自分の手番で終了する場合は、最後に相手がパスをして終了するものと考えます。)\n", | |
"\n", | |
"その1つ前の自分の手番では、ミニマックス法に従って決まる次の盤面の評価値を V として、γV をその盤面の評価値とします。(「あと一歩で評価値 V の最終盤面に達成する」という意味で、γ 倍しています。)\n", | |
"\n", | |
"一般に、ある盤面の評価値は、ミニマックス法に従って決まる次の盤面の評価値 V を用いて、γV として、帰納的に定義することが可能になります。\n", | |
"\n", | |
"### 数学的な表現\n", | |
"\n", | |
"前述のようにして評価関数 V が決まったとすると、これは次の性質を満たします。\n", | |
"\n", | |
"最終盤面以外の盤面 s に対して、ミニマックス法で決まる次の盤面を s' とすると:\n", | |
"\n", | |
" V(s) = γV(s') −−− (1)\n", | |
"\n", | |
"最終盤面 s'' に対して、その盤面の評価値(「自分のコマ数 - 相手のコマ数」)を r(s'') とすると:\n", | |
"\n", | |
" V(s'') = r(s'') −−− (2)\n", | |
" \n", | |
"### ニューラルネットワークによる評価関数の近似表現\n", | |
"\n", | |
"盤面の状態は相当な数になるため、すべての盤面について個別に評価値を記録することは困難です。ここでは、盤面の状態を入力すると評価値が出力されるニューラルネットワークを用意して、評価関数を近似的に表現します。\n", | |
"\n", | |
"### モンテカルロ法によるニューラルネットワークの学習\n", | |
"\n", | |
"ニューラルネットワークで評価関数を再現するためには、あらゆる盤面 s に対して、(1)(2)を満たす関数を目指す必要があります。とはいえ、本当にあらゆる盤面を評価することはできませんので、ここでは、コンピューター同士の対戦データによって盤面のサンプルを用意します。\n", | |
"\n", | |
"まず、現在の評価関数を利用したミニマックス法を用いて、コンピューター同士で対戦させて、1回分の対戦データを生成します。この対戦における自分の手番の盤面の変化(s(1)→s(2)→・・・→s(n)) と最終的な勝ち点 S が決まると、そこに含まれる盤面について、あるべき評価値が次式で決まります。\n", | |
"\n", | |
" V(s(n)) = S\n", | |
" V(s(n-1)) = γS\n", | |
" V(s(n-2)) = γV(s(n-1))\n", | |
" ・・・\n", | |
" \n", | |
"そこで、ニューラルネットワークが出力する現在の評価値と上記で計算される評価値の2乗誤差を定義して、これが小さくなる方向にニューラルネットワークのパラメーターを調整します。パラメーター調整後の評価関数を用いて、再度、コンピューター同士で対戦して・・・という事を繰り返すことで、ニューラルネットワークは本来の評価関数に近づいていくものと期待されます。\n", | |
"\n", | |
"ここでコンピューター同士で対戦させる際、既存の他のアルゴリズムと対戦させることにより、そのアルゴリズムよりもさらに優秀な評価関数が生成できる可能性が考えられます。ここでは、対戦相手として、その盤面における「自分のコマ数 − 相手のコマ数」を評価関数とするミニマックス法を使用します。これは、目先のコマ数をなるべく増やそうとするアルゴリズムになります。\n", | |
"\n", | |
"より具体的な実装、および、シュミレーション方法については、この後のコードと一緒に解説を進めます。" | |
] | |
}, | |
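{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a small worked example (added for illustration, with arbitrary values), the next cell turns one final score into the recursive targets V(s(n)) = S, V(s(n-1)) = γS, ... defined above. This mirrors the target computation performed later in update_model()." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Worked example of the recursive value targets (illustration only).\n", | |
"# Assumption: S and n below are arbitrary illustrative numbers, not data from a real game.\n", | |
"gamma = 0.8\n", | |
"S = 10.0   # hypothetical final score ('my stones - opponent stones')\n", | |
"n = 5      # hypothetical number of my own turns in the game\n", | |
"targets = [S]\n", | |
"for _ in range(n - 1):\n", | |
"    targets.append(gamma * targets[-1])\n", | |
"targets.reverse()   # targets[0] is for the earliest board s(1), targets[-1] for the final board s(n)\n", | |
"print targets       # roughly [4.1, 5.1, 6.4, 8.0, 10.0]" | |
] | |
}, | |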
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 初期設定" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"必要なモジュールをインポートします。また、board_sizeには、盤面の1辺のマス数を与えます。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import random, math, copy, os\n", | |
"board_size = 8" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ニューラルネットワークの定義" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ニューラルネットワークのスーパークラスを定義します。\n", | |
"\n", | |
"get_values()メソッドは、盤面を表す2次元リストから評価値を計算します。正確には、盤面のリストを受け取って、評価値のリストを返します。1つの盤面は、1次元リストにフラット化した状態で渡します。また、評価値のリストに含まれる各評価値は、1要素のみのリストです。(つまり、評価値のリストは、[[1],[3],[2],・・・] という形式。)\n", | |
"\n", | |
"update_model()メソッドは、複数の対戦記録を含むリストを用いて、ニューラルネットワークのパラメーターを更新します。1回の対戦記録は、自分の手番の盤面のリストで、最後の要素に勝ち点が入っています。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"class Qnet(object): # super class\n", | |
" def __init__(self):\n", | |
" pass\n", | |
"\n", | |
" def get_values(self, boards):\n", | |
" xs = [sum(board, []) for board in boards]\n", | |
" values = self.q.eval(session=self.sess, feed_dict={self.x: xs})\n", | |
" return values\n", | |
"\n", | |
" # Montecalro update\n", | |
" def update_model(self, episodes):\n", | |
" gamma = 0.8\n", | |
" boards, targets = [], []\n", | |
" for episode in episodes:\n", | |
" localepisode = copy.deepcopy(episode)\n", | |
" targets.append([localepisode.pop()]) # Final reward\n", | |
" boards.append(sum(localepisode.pop(), [])) # Last board\n", | |
" while (localepisode):\n", | |
" boards.append(sum(localepisode.pop(), []))\n", | |
" targets.append([gamma * targets[-1][0]])\n", | |
" self.sess.run(self.train_step, feed_dict={self.x:boards, self.y_:targets})" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ニューラルネットワークの実体となるクラスを定義します。\n", | |
"\n", | |
"ここでは、4x4の畳み込みフィルター32枚(+ReLUによるカットオフ)を2層分適用した後、512ユニットの全結合層を適用します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class DoubleCNN(Qnet):\n", | |
" def __init__(self):\n", | |
"\n", | |
" x = tf.placeholder(tf.float32, [None, board_size**2])\n", | |
" x_image = tf.reshape(x, [-1,board_size,board_size,1])\n", | |
"\n", | |
" num_filters1 = 32\n", | |
" W_conv1 = tf.Variable(tf.truncated_normal([4,4,1,num_filters1], stddev=0.1))\n", | |
" h_conv1 = tf.nn.conv2d(x_image, W_conv1, strides=[1,1,1,1], padding='SAME')\n", | |
" b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))\n", | |
" h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)\n", | |
"\n", | |
" num_filters2 = 32\n", | |
" W_conv2 = tf.Variable(tf.truncated_normal([4,4,num_filters1,num_filters2], stddev=0.1))\n", | |
" h_conv2 = tf.nn.conv2d(h_conv1_cutoff, W_conv2, strides=[1,1,1,1], padding='SAME')\n", | |
" b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))\n", | |
" h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)\n", | |
" \n", | |
" h_flat = tf.reshape(h_conv2_cutoff, [-1, (board_size**2)*num_filters2])\n", | |
" w1 = tf.Variable(tf.truncated_normal([(board_size**2)*num_filters2, 512],\n", | |
" stddev=1.0/math.sqrt(2.0)))\n", | |
" b1 = tf.Variable(tf.zeros([512]))\n", | |
" hidden1 = tf.nn.relu(tf.matmul(h_flat, w1) + b1)\n", | |
" \n", | |
" w = tf.Variable(tf.truncated_normal([512, 1], stddev=1.0/math.sqrt(2.0)))\n", | |
" b = tf.Variable(tf.zeros([1]))\n", | |
" q = tf.matmul(hidden1, w) + b\n", | |
"\n", | |
" y_ = tf.placeholder(tf.float32, [None, 1])\n", | |
" loss = tf.reduce_mean(tf.square(y_ - q))\n", | |
" train_step = tf.train.AdamOptimizer().minimize(loss)\n", | |
"\n", | |
" self.x = x\n", | |
" self.q = q\n", | |
" self.y_ = y_\n", | |
" self.loss = loss\n", | |
" self.train_step = train_step\n", | |
" \n", | |
" self.sess = tf.Session()\n", | |
" self.sess.run(tf.initialize_all_variables())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## リバーシのゲーム環境の定義" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"リバーシのゲーム環境を提供するクラスを定義します。\n", | |
"\n", | |
"インスタンス作成時に board オプションで最初の盤面を指定することができます。盤面は、8x8の2次元リストで、各要素は以下の数字を含みます。\n", | |
"\n", | |
"- 0 : 空き\n", | |
"- 1 : 先手の石\n", | |
"- -1 : 後手の石\n", | |
" \n", | |
"次のようなメソッドを提供します。playerは ±1 で先手と後手を表します。\n", | |
"\n", | |
"- show_board() : その時点の盤面を表示します。\n", | |
"- is_allowed((x,y), player) : 現在の盤面で (x,y) に player の石が置けるか判定します。\n", | |
"- get_candidates(player) : \n", | |
"- get_reward(player) : playerから見た時の勝ち点を返します。盤面がまだゲーム途中の場合は0、引き分けの場合は0.1を返します(この返り値が0以外の場合に、ゲーム終了と判定するため)。\n", | |
"- update_state((x,y), player) : (x,y) に player の石を置いて、返した石の数を返します。(置けない場所を指定した場合は、何もせずに0を返します。)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"class Environ:\n", | |
" def __init__(self, board=None):\n", | |
" if board:\n", | |
" self.board = board\n", | |
" else:\n", | |
" self.board = [[0 for x in range(board_size)] for y in range(board_size)]\n", | |
" self.board[3][3] = 1\n", | |
" self.board[4][4] = 1\n", | |
" self.board[3][4] = -1\n", | |
" self.board[4][3] = -1\n", | |
"\n", | |
" def show_board(self):\n", | |
" stone = {'0': '-', '1': 'o', '-1': 'x'}\n", | |
" print '0 1 2 3 4 5 6 7'\n", | |
" for y in range(board_size):\n", | |
" for x in range(board_size):\n", | |
" print stone[str(self.board[y][x])],\n", | |
" print '%d' % y\n", | |
" print ''\n", | |
"\n", | |
" def is_allowed(self, (x,y), player):\n", | |
" board = self.board\n", | |
" opponent = -player\n", | |
" if board[y][x] != 0:\n", | |
" return False\n", | |
" directions = [(dx,dy) for dx in [-1,0,1] for dy in [-1,0,1] if not (dx == 0 and dy == 0)]\n", | |
" for direction in directions:\n", | |
" posx, posy = x, y\n", | |
" state = 0\n", | |
" while (True):\n", | |
" posx += direction[0]\n", | |
" posy += direction[1]\n", | |
" if posx < 0 or posx >= board_size or posy < 0 or posy >= board_size:\n", | |
" break\n", | |
" cell = board[posy][posx]\n", | |
" if state == 0:\n", | |
" if cell == opponent:\n", | |
" state = 1\n", | |
" continue\n", | |
" else:\n", | |
" break\n", | |
" if state == 1:\n", | |
" if cell == player:\n", | |
" return True\n", | |
" elif cell != opponent:\n", | |
" break\n", | |
" return False\n", | |
" \n", | |
" def get_reward(self, player):\n", | |
" board = self.board\n", | |
" opponent = -player\n", | |
"\n", | |
" finish = True\n", | |
" grid = [(x,y) for y in range(board_size) for x in range(board_size)]\n", | |
" for (x,y) in grid:\n", | |
" if self.is_allowed((x,y), player=1) or self.is_allowed((x,y), player=-1):\n", | |
" finish = False\n", | |
" break\n", | |
" if not finish:\n", | |
" return 0\n", | |
" score_p = len([True for y in range(board_size)\n", | |
" for x in range(board_size)\n", | |
" if board[y][x] == player])\n", | |
" score_o = len([True for y in range(board_size)\n", | |
" for x in range(board_size)\n", | |
" if board[y][x] == opponent])\n", | |
" if score_p == score_o:\n", | |
" return 0.1\n", | |
" return score_p - score_o\n", | |
" \n", | |
" def get_candidates(self, player):\n", | |
" candidates = []\n", | |
" grid = [(x,y) for y in range(board_size) for x in range(board_size)]\n", | |
" for (x,y) in grid:\n", | |
" if self.is_allowed((x,y), player=player):\n", | |
" candidates.append((x,y))\n", | |
" return candidates\n", | |
" \n", | |
" def update_state(self, (x, y), player):\n", | |
" if not self.is_allowed((x,y), player):\n", | |
" return None\n", | |
" self.board[y][x] = player\n", | |
" board = self.board\n", | |
" opponent = -player\n", | |
" score = 0\n", | |
" directions = [(dx,dy) for dx in [-1,0,1] for dy in [-1,0,1] if not (dx == 0 and dy == 0)]\n", | |
"\n", | |
" for direction in directions:\n", | |
" posx, posy = x, y\n", | |
" state = 0\n", | |
" candidates = []\n", | |
" while (True):\n", | |
" posx += direction[0]\n", | |
" posy += direction[1]\n", | |
" if posx < 0 or posx >= board_size or posy < 0 or posy >= board_size:\n", | |
" break\n", | |
" cell = board[posy][posx]\n", | |
" if state == 0:\n", | |
" if cell == opponent:\n", | |
" state = 1\n", | |
" candidates.append((posx,posy))\n", | |
" continue\n", | |
" else:\n", | |
" break\n", | |
" if state == 1:\n", | |
" if cell == player:\n", | |
" for (x0, y0) in candidates:\n", | |
" self.board[y0][x0] = player\n", | |
" score += 1\n", | |
" break\n", | |
" elif cell == opponent:\n", | |
" candidates.append((posx,posy))\n", | |
" continue\n", | |
" else:\n", | |
" break\n", | |
" return score " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### ゲーム環境の動作確認" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ゲーム環境のインスタンスを作成して、盤面の状態を表示します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - - - - - 1\n", | |
"- - - - - - - - 2\n", | |
"- - - o x - - - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - - - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"env = Environ()\n", | |
"env.show_board()" | |
] | |
}, | |
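{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As an extra check added for illustration, the next cell calls is_allowed() and get_candidates() directly on the opening position." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Added check (illustration): legal-move queries on the opening position.\n", | |
"print env.is_allowed((4, 2), player=1)   # True: placing here flanks the stone at (4,3)\n", | |
"print env.is_allowed((0, 0), player=1)   # False: no opponent stone is flanked\n", | |
"print env.get_candidates(player=1)       # the four legal opening moves for the first player" | |
] | |
}, | |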
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"player=1(先手)のコマを置いて、正しく動作することを確認します。この後、pleayer=-1(後手)のコマもおいて、ゲームが進行できることを確認してください。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - - - - - 1\n", | |
"- - - - o - - - 2\n", | |
"- - - o o - - - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - - - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"env.update_state((4,2), player=1)\n", | |
"env.show_board()" | |
] | |
}, | |
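{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The following cell is a sketch of the suggested follow-up: it lets the second player (player=-1) respond by taking the first legal move returned by get_candidates(), confirming that the game can proceed." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: let the second player respond with its first legal move.\n", | |
"candidates = env.get_candidates(player=-1)\n", | |
"print candidates                        # legal replies for the second player\n", | |
"env.update_state(candidates[0], player=-1)\n", | |
"env.show_board()" | |
] | |
}, | |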
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ミニマックス法の定義" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ミニマックス法を用いて、次に打つ手を決定するエージェントのクラスを定義します。\n", | |
"\n", | |
"get_action()メソッドに、qnet(ニューラルネットワーク)、ゲーム環境(env)、手番(player)を与えると次に打つ手 (x,y) が返ります。\n", | |
"\n", | |
"この際、trainオプションをTrueに設定すると、epsilonで指定した割合でランダムな手を返します。これは、さまざまな盤面のデータを収集するために、学習中にあえてランダムな手を混ぜるために利用します。これは「ε-greedyポリシー」と呼ばれる手法になります。\n", | |
"\n", | |
"また、インスタンス変数episodesに過去の対戦記録を保存することが可能です。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"class Agent:\n", | |
" def __init__(self):\n", | |
" self.episodes = []\n", | |
" \n", | |
" # mini-max method\n", | |
" def get_action(self, qnet, env, player, train=False, epsilon=0):\n", | |
" candidates = env.get_candidates(player)\n", | |
" if len(candidates) == 0:\n", | |
" return None\n", | |
" if train and np.random.random() < epsilon:\n", | |
" random.shuffle(candidates)\n", | |
" return candidates[0]\n", | |
"\n", | |
" next_boards = []\n", | |
" for (x,y) in candidates:\n", | |
" localenv = Environ(board=copy.deepcopy(env.board))\n", | |
" localenv.update_state((x,y), player)\n", | |
" next_boards.append(copy.deepcopy(localenv.board))\n", | |
"\n", | |
" # Estimate opponents move for each candidate\n", | |
" values = []\n", | |
" for next_board in next_boards:\n", | |
" localenv = Environ(board=copy.deepcopy(next_board))\n", | |
" next_candidates = localenv.get_candidates(-player)\n", | |
" if len(next_candidates) == 0:\n", | |
" # the opponent to pass.\n", | |
" # Note that the meaning of 'value' depends on the player\n", | |
" values.append(player * qnet.get_values([next_board])[0][0])\n", | |
" continue\n", | |
" next_next_boards = []\n", | |
" for (x,y) in next_candidates:\n", | |
" localenv = Environ(board=copy.deepcopy(next_board))\n", | |
" localenv.update_state((x,y), -player)\n", | |
" next_next_boards.append(copy.deepcopy(localenv.board))\n", | |
" # Note that the meaning of 'value' depends on the player\n", | |
" value = min(player * qnet.get_values(next_next_boards))[0]\n", | |
" values.append(value)\n", | |
"\n", | |
" action = candidates[np.argmax(values)]\n", | |
" return action" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 対戦のシュミレーション" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"その盤面における「自分のコマ数 − 相手のコマ数」を返す評価関数を持った、ダミーのニューラルネットワークを定義します。これを対戦相手として利用します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class SimpleMiniMax:\n", | |
" def get_values(self, boards):\n", | |
" result = []\n", | |
" for board in boards:\n", | |
" score = 0.0\n", | |
" # The value is defined from the player=1's point of view.\n", | |
" for c in sum(board, []):\n", | |
" if c == 1: score += 1\n", | |
" if c == -1: score -= 1\n", | |
" result.append([score])\n", | |
" return np.array(result)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"続いて、1回分のコンピューター同士の対戦をシュミレーションする関数を定義します。\n", | |
"\n", | |
"先手は、学習対象のAIで、前述のようにε-greedyポリシーを適用します。epsilonの値は、大きな値から小さな値へと変化させながら学習を進めます。(trainオプションをFalseに設定すると、epsilon=0(ランダムな手は混ぜない)に設定されます。)\n", | |
"\n", | |
"後手は、ダミーのニューラルネットワークを用いたミニマックス法でシュミレーションしますが、epsilonの値は0.1に固定しています。(対戦相手の手にランダムな要素が無いと、その相手に勝つためだけのオーバーフィッティングが発生するためです。)\n", | |
"\n", | |
"trainオプションがTrueの場合、対戦が終わるとその記録をエージェントのepisodesに保存して、ニューラルネットワークのパラメーター修正を実施します。過去3回分の記録を残しておき、それらをまとめてバッチ適用します。つまり、1回の対戦記録は、パラメーター修正に3回利用されることになります。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def run(qnet, agent, train, epsilon):\n", | |
" env = Environ()\n", | |
" log = []\n", | |
" episode_record = []\n", | |
" simple_minimax = SimpleMiniMax()\n", | |
"\n", | |
" while (True):\n", | |
" # AI Player\n", | |
" episode_record.append(copy.deepcopy(env.board))\n", | |
" action = agent.get_action(qnet, env, player=1, train=True, epsilon=epsilon)\n", | |
" if action:\n", | |
" env.update_state(action, player=1)\n", | |
" log.append(copy.deepcopy(env.board))\n", | |
" \n", | |
" # Human Player\n", | |
" action = agent.get_action(simple_minimax, env, player=-1, train=True, epsilon=0.1)\n", | |
" if action:\n", | |
" env.update_state(action, player=-1)\n", | |
" log.append(copy.deepcopy(env.board))\n", | |
" reward = env.get_reward(player=1)\n", | |
" if reward != 0:\n", | |
" episode_record.append(reward) # Episode ends with the reward value.\n", | |
" break\n", | |
" \n", | |
" if train:\n", | |
" agent.episodes.append(episode_record)\n", | |
" if len(agent.episodes) > 3:\n", | |
" agent.episodes = agent.episodes[1:]\n", | |
" qnet.update_model(agent.episodes)\n", | |
"\n", | |
" return reward, log" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## モンテカルロシュミレーションの実施" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"コンピューター同士の対戦のシュミレーションを繰り返して、対戦データを用いたチューニングを行います。epsilonを変化させる条件は次の通りです。\n", | |
"\n", | |
"- 先手は epsilon=0.9 から開始して、30回対戦するごとに0.9倍に減少させていきます。最終的に1500回対戦して、epsilon≒0.05 になります。\n", | |
"- 後手は epsilon=0.1 に固定します。\n", | |
"\n", | |
"1回対戦するごとに勝敗を記号(先手の勝ち、もしくは引分けは「+」、先手の負けは「-」)で示します。10回対戦するごとに、epsilon=0 に設定した「本気モード」で対戦してその結果を表示した上で、その時点のパラメーターをディレクトリー「ReversiData01」以下に保存します。" | |
] | |
}, | |
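{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The next cell is a quick sanity check of the epsilon schedule, added for illustration: the value used for the k-th block of 30 games is 0.9 * 0.9^(k-1), which should match the epsilon values printed in the training log below." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Sanity check of the epsilon decay schedule (illustration only).\n", | |
"for k in [1, 10, 25, 50]:\n", | |
"    print 'block %d (up to game %d): epsilon = %.3f' % (k, 30 * k, 0.9 * 0.9 ** (k - 1))" | |
] | |
}, | |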
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"- - - + - - - - + - - - - - - - - - + - + - + - - - + - + - step 30, epsilon 0.900, reward 6\n", | |
"+ + + - + + - - - + + - + - - - - + - + - - - + + - + - - - step 60, epsilon 0.810, reward -20\n", | |
"+ - - - - - - + + - - - + - - + + + - - - - + - - - - - - + step 90, epsilon 0.729, reward -6\n", | |
"- - + + - - + - - + - - + - + + - + - - - - - - - - - - - - step 120, epsilon 0.656, reward -8\n", | |
"- - - + + + - - + - - + + - - - - + - + - + - - - - - + + - step 150, epsilon 0.590, reward 40\n", | |
"- - - + - - + + + - - - - + - - - + - - - - - - - - - + - + step 180, epsilon 0.531, reward -28\n", | |
"- - - + - - - + - - + - - + + - + - + + + + - - - - - - - - step 210, epsilon 0.478, reward -16\n", | |
"- - + + + - + - + + - - + + + + - + - + - - + + + - - + - - step 240, epsilon 0.430, reward -12\n", | |
"- + - - - + - + - - + - - + - - - + - - - - + - + + - + - - step 270, epsilon 0.387, reward -18\n", | |
"- - + - - + - - - - + - + - - + - - - - - + + - - + + + - - step 300, epsilon 0.349, reward -6\n", | |
"+ - - + - - - + + - - - - - - + - + + + - - - - - - - - - - step 330, epsilon 0.314, reward -10\n", | |
"+ - - - - + - - + + - - - - - - + + - - + - - + - - - - + + step 360, epsilon 0.282, reward -50\n", | |
"+ - + - + - + - + - + - + + - - - + - + - - - + - + + + - - step 390, epsilon 0.254, reward -20\n", | |
"+ - - + - + + + - - - - - - - - - - - + - - - + - - + - - - step 420, epsilon 0.229, reward -32\n", | |
"- - + + - - - + - + - + - - - - - - - + - + - + + - - + + + step 450, epsilon 0.206, reward -14\n", | |
"+ + + - + - + + + - - + - - + + + + + - - - - + + - - - + + step 480, epsilon 0.185, reward 22\n", | |
"+ - + + + - + - + + + - + + - + - + + - - - + + - - - - - + step 510, epsilon 0.167, reward 16\n", | |
"+ - - + - - + + - + - + - - + + + + - - - - - - - + + - + + step 540, epsilon 0.150, reward -14\n", | |
"- + - - - - - - - - + - - + - + + - + - + + - + - + + + - + step 570, epsilon 0.135, reward 34\n", | |
"- - - + + + + + + + + + + + - + + + + - + + + + - + - - + + step 600, epsilon 0.122, reward 26\n", | |
"+ - - - + + + + + - - + + - - + + - + + + + + - + + + - - - step 630, epsilon 0.109, reward -48\n", | |
"- + - + + - - - + - + - - + - - + - + + - - + + + - - - + - step 660, epsilon 0.098, reward 20\n", | |
"+ - - + - - - - - - - - - - - - - - - - + - - + - + + + - - step 690, epsilon 0.089, reward 10\n", | |
"+ - - + - - + + + - - + - - - + - - + + - + + - - - - - + - step 720, epsilon 0.080, reward -18\n", | |
"+ - - - - + + + - + + - + + - - - - - - - - + - + + + - - + step 750, epsilon 0.072, reward 12\n", | |
"- + + - - + + - + - - + + + + - + + - + + + + - - + - - - - step 780, epsilon 0.065, reward -14\n", | |
"+ + + - - + - - + + + - + + + + - - - + - - + + + - - + + - step 810, epsilon 0.058, reward -4\n", | |
"- - + - - + - + + + - + + - - - - + + - - - + + - - + - + - step 840, epsilon 0.052, reward -22\n", | |
"+ + - + + + - - - - + + - + + + - - - - - + + - - - - - - + step 870, epsilon 0.047, reward -22\n", | |
"- - + - + + - - - + - - + - - + + + - + - - + + + + - + + - step 900, epsilon 0.042, reward -2\n", | |
"+ + + - + + - - - + + - + + + - + - + - - + - + + - + - + + step 930, epsilon 0.038, reward 8\n", | |
"- - - + - + - - + + + + - - + - - + + + + - + + - - - + + - step 960, epsilon 0.034, reward 26\n", | |
"- - - - + + - + - - + - + + + + + + + - + + + - + + + - + - step 990, epsilon 0.031, reward 8\n", | |
"- + - - + - + - + + + + + - + - - + + - - + + + + + + + - + step 1020, epsilon 0.028, reward -34\n", | |
"+ + + + + + + + + - + + - - + - - + - + - + + + - + + + + - step 1050, epsilon 0.025, reward -22\n", | |
"- - - - + - + + + + + - + + + + + + + + - + - + - + - + + + step 1080, epsilon 0.023, reward -34\n", | |
"- + + - - + - + + + + + + + + + - + + - + - + - - - - + - + step 1110, epsilon 0.020, reward -16\n", | |
"- - + + + + - - - + + + + - + + - + + + + + - + + - + + - - step 1140, epsilon 0.018, reward 12\n", | |
"+ + - + + - + - + - + + + + - + - - + + + + + + + + + + + - step 1170, epsilon 0.016, reward 4\n", | |
"+ + + + + + + + + - + + - + + + + + + - + + - - + + + - + - step 1200, epsilon 0.015, reward 8\n", | |
"+ - - - + + - + - - + - - + + + - + + - + + - + - - - - + + step 1230, epsilon 0.013, reward -8\n", | |
"+ + + - - + - - + + + - - + - + - + + + - + + + - - - - + + step 1260, epsilon 0.012, reward 6\n", | |
"+ - + - + + + - - - + + - + - + + - - + + + - + - - + + - - step 1290, epsilon 0.011, reward 16\n", | |
"- + + + - - + + + + + + - + + + + + + + - + + + - + - - - + step 1320, epsilon 0.010, reward -26\n", | |
"+ - + - + + + - + + + - - - - + + + - - + + - - + + - + - + step 1350, epsilon 0.009, reward -10\n", | |
"- + - + - - + + - - + - - - - + + - + - - + + - + + + + + + step 1380, epsilon 0.008, reward 30\n", | |
"+ + - + + + + + + + - + - - + - + + - + + - - - - - + + - + step 1410, epsilon 0.007, reward 34\n", | |
"+ + + + + - + + + - + - + + + + - + + + + + - + + + - - + + step 1440, epsilon 0.006, reward 10\n", | |
"+ + + - + - - - + + - - + + - - - - + + + + + + + - - + - + step 1470, epsilon 0.006, reward 14\n", | |
"+ - + + + - - - + + - + - - + + + + + + - + - + - + + - + + step 1500, epsilon 0.005, reward 0\n" | |
] | |
} | |
], | |
"source": [ | |
"label = 'ReversiData01'\n", | |
"try:\n", | |
" os.mkdir(label)\n", | |
"except OSError:\n", | |
" pass\n", | |
"qnet = DoubleCNN()\n", | |
"agent = Agent()\n", | |
"saver = tf.train.Saver()\n", | |
"\n", | |
"epsilon = 0.9\n", | |
"for i in range(1, 1501):\n", | |
" reward, log = run(qnet, agent, train=True, epsilon=epsilon)\n", | |
" if reward > 0:\n", | |
" print '+',\n", | |
" else:\n", | |
" print '-',\n", | |
" if i % 30 == 0:\n", | |
" saver.save(qnet.sess, label + '/train_data', global_step=i)\n", | |
" reward, log = run(qnet, agent, train=False, epsilon=0)\n", | |
" print 'step %d, epsilon %.3f, reward %d' % (i, epsilon, reward)\n", | |
" epsilon *= 0.9" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"この時点の学習結果を用いて、ダミーのニューラルネットワークを用いたミニマックス法と100回対戦して結果を確認します。(対戦相手の手にランダムな要素がないとすべて同じ結果になるので、対戦相手は epsilon=0.1 のままとしています。)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+ + + + + + + + + + + + + + + - + + + - + + + - + + + + - + + + - - + + + - + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + - + + + + - + - + + + + + + + + - + + - - - + - - + + - - \n", | |
"Average score 17.2, Winning rate 81%\n" | |
] | |
} | |
], | |
"source": [ | |
"n = 100\n", | |
"score, winrate = 0.0, 0.0\n", | |
"for _ in range(n):\n", | |
" reward, log = run(qnet, agent, train=False, epsilon=0)\n", | |
" if reward > 0:\n", | |
" winrate += 1.0\n", | |
" score += reward\n", | |
" print '+',\n", | |
" else:\n", | |
" print '-',\n", | |
"print ''\n", | |
"print 'Average score %.1f, Winning rate %d%%' % (score/n, winrate*100/n)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"この例では、勝率は81%で、「自分のコマ数 − 相手のコマ数」を評価関数とする単純なミニマックス法よりも性能向上が確認できました。\n", | |
"\n", | |
"次のコードを実行すると、1回だけ対戦して盤面の変化を表示します。繰り返し実行して、それぞれの打ち手の傾向を観察してください。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - - - - - 1\n", | |
"- - - - - - - - 2\n", | |
"- - - o x - - - 3\n", | |
"- - - o o - - - 4\n", | |
"- - - o - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - - - - - 1\n", | |
"- - - - - - - - 2\n", | |
"- - x x x - - - 3\n", | |
"- - - o o - - - 4\n", | |
"- - - o - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - - - - - 1\n", | |
"- - - o - - - - 2\n", | |
"- - x o x - - - 3\n", | |
"- - - o o - - - 4\n", | |
"- - - o - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - x - - - 1\n", | |
"- - - x - - - - 2\n", | |
"- - x o x - - - 3\n", | |
"- - - o o - - - 4\n", | |
"- - - o - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - x - - - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x o o - - - 3\n", | |
"- - - o o - - - 4\n", | |
"- - - o - - - - 5\n", | |
"- - - - - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - - - - - - 0\n", | |
"- - - - x - - - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x x o - - - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - - - - 0\n", | |
"- - - - o - - - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x x o - - - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - - - - 0\n", | |
"- - - - o - x - 1\n", | |
"- - - x - x - - 2\n", | |
"- - x x x - - - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - - - - 0\n", | |
"- - - - o - x - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x x x - o - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - - 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x x x - o - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - o - 1\n", | |
"- - - x - o - - 2\n", | |
"- - x x x - o - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - o - x 2\n", | |
"- - x x x - o - 3\n", | |
"- - - x o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - o - x 2\n", | |
"- - x x x - o - 3\n", | |
"- - o o o - - - 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - x - x 2\n", | |
"- - x x x - x - 3\n", | |
"- - o o o - - x 4\n", | |
"- - - x - - - - 5\n", | |
"- - - x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - x - x 2\n", | |
"- - x x x - x - 3\n", | |
"- - o o o - - x 4\n", | |
"- - - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - x - x 2\n", | |
"- - x x x - x - 3\n", | |
"- - x o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - x - x 2\n", | |
"- - x x x - x - 3\n", | |
"- o o o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x - x - x 2\n", | |
"- x x x x - x - 3\n", | |
"- x o o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x - x - 1\n", | |
"- - - x o x - x 2\n", | |
"- x x o o - x - 3\n", | |
"- x o o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"- - - x x x - x 2\n", | |
"- x x x o - x - 3\n", | |
"- x x o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o x - - - - 6\n", | |
"- - - - - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"- - - x x x - x 2\n", | |
"- x x x o - x - 3\n", | |
"- x x o o - - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o o - - - - 6\n", | |
"- - - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"- - - x x x - x 2\n", | |
"- x x x x - x - 3\n", | |
"- x x x x x - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o o - - - - 6\n", | |
"- - - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o - - x x x - x 2\n", | |
"- o x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- x - o - - - - 5\n", | |
"- - o o - - - - 6\n", | |
"- - - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o - - x x x - x 2\n", | |
"- o x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- x - x - - - - 5\n", | |
"- - x o - - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o - - x x x - x 2\n", | |
"- o x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o - x - - - - 5\n", | |
"o - x o - - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"- x x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o - x - - - - 5\n", | |
"o - x o - - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"- x x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o - x - - - - 5\n", | |
"o o o o - - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"- x x x x - x - 3\n", | |
"- x x x x x - x 4\n", | |
"- o x x - - - - 5\n", | |
"o o o o - - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"- o x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o x o - - - - 5\n", | |
"o o o o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"x x x x x - x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o x o - - - - 5\n", | |
"o o o o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x - x 2\n", | |
"x x x x x o x - 3\n", | |
"- x o x o x - x 4\n", | |
"- o x o - - - - 5\n", | |
"o o o o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o x x x - x 4\n", | |
"- o x x - - - - 5\n", | |
"o o x o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o o o o o x 4\n", | |
"- o x x - - - - 5\n", | |
"o o x o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- - - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o o o x x x 4\n", | |
"- o x x - - x - 5\n", | |
"o o x o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- - - o - x - o 0\n", | |
"- o - - x x x - 1\n", | |
"o o - x x x x x 2\n", | |
"x o x x x x x - 3\n", | |
"- o o o o x x x 4\n", | |
"- o x x - - x - 5\n", | |
"o o x o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o o o x x x 4\n", | |
"- x x x - - x - 5\n", | |
"o x x o o - - - 6\n", | |
"- x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o o o x x x 4\n", | |
"- x o x - - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"x x x x x x x - 3\n", | |
"- x o x x x x x 4\n", | |
"- x o x x - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"o x x x x x x - 3\n", | |
"o o o x x x x x 4\n", | |
"- o o x x - x - 5\n", | |
"o o o o o - - - 6\n", | |
"o x - o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - - x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"o x x x x x x - 3\n", | |
"o o x x x x x x 4\n", | |
"- o x x x - x - 5\n", | |
"o o x x o - - - 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x - o x x x - 1\n", | |
"o x - o x x x x 2\n", | |
"o x x o x x x - 3\n", | |
"o o x o x x x x 4\n", | |
"- o x o x - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o - x - o 0\n", | |
"- x x x x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"o x x o x x x - 3\n", | |
"o o x o x x x x 4\n", | |
"- o x o x - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x - o o x - o 0\n", | |
"- x x x o x x - 1\n", | |
"o x - x o x x x 2\n", | |
"o x x o o x x - 3\n", | |
"o o x o o x x x 4\n", | |
"- o x o o - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"- x x x o x x - 1\n", | |
"o x - x o x x x 2\n", | |
"o x x o o x x - 3\n", | |
"o o x o o x x x 4\n", | |
"- o x o o - x - 5\n", | |
"o o x o o - - - 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"- x x x o x x - 1\n", | |
"o x - x o x x x 2\n", | |
"o x x o o x x - 3\n", | |
"o o x o o o x x 4\n", | |
"- o x o o - o - 5\n", | |
"o o x o o - - o 6\n", | |
"o x x o - - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"- x x x x x x - 1\n", | |
"o x - x x x x x 2\n", | |
"o x x o x x x - 3\n", | |
"o o x o x o x x 4\n", | |
"- o x o x - o - 5\n", | |
"o o x x x - - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"o x x x x x x - 1\n", | |
"o o - x x x x x 2\n", | |
"o x o o x x x - 3\n", | |
"o o x o x o x x 4\n", | |
"- o x o x - o - 5\n", | |
"o o x x x - - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"o x x x x x x - 1\n", | |
"o o - x x x x x 2\n", | |
"o x x o x x x - 3\n", | |
"o x x o x o x x 4\n", | |
"x x x o x - o - 5\n", | |
"o x x x x - - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"- x x x x x - o 0\n", | |
"o o o o o o o o 1\n", | |
"o o - x x x o x 2\n", | |
"o x x o x o x - 3\n", | |
"o x x o o o x x 4\n", | |
"x x x o x - o - 5\n", | |
"o x x x x - - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x o o o o o o o 1\n", | |
"x o - x x x o x 2\n", | |
"x x x o x o x - 3\n", | |
"x x x o o o x x 4\n", | |
"x x x o x - o - 5\n", | |
"o x x x x - - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x o o o o o o o 1\n", | |
"x o - x x x o x 2\n", | |
"x x x o x o x - 3\n", | |
"x x x o o o x x 4\n", | |
"x x x o o - o - 5\n", | |
"o o o o o o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x x x x o o o o 1\n", | |
"x x x x x x o x 2\n", | |
"x x x o x o x - 3\n", | |
"x x x o o o x x 4\n", | |
"x x x o o - o - 5\n", | |
"o o o o o o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x x x x o o o o 1\n", | |
"x x x x x x o o 2\n", | |
"x x x o x o o o 3\n", | |
"x x x o o o x x 4\n", | |
"x x x o o - o - 5\n", | |
"o o o o o o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x x x x o o o o 1\n", | |
"x x x x x x o o 2\n", | |
"x x x x x x o o 3\n", | |
"x x x o x x x x 4\n", | |
"x x x x x x o - 5\n", | |
"o o o o x o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x - o 0\n", | |
"x x x x o o o o 1\n", | |
"x x x x x x o o 2\n", | |
"x x x x x x o o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x o o 5\n", | |
"o o o o x o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x x o 0\n", | |
"x x x x o x x o 1\n", | |
"x x x x x x x o 2\n", | |
"x x x x x x x o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x o o 5\n", | |
"o o o o x o - o 6\n", | |
"o x x x x - - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x x o 0\n", | |
"x x x x o x x o 1\n", | |
"x x x x x x x o 2\n", | |
"x x x x x x x o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x o o 5\n", | |
"o o o o x o - o 6\n", | |
"o o o o o o - - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x x o 0\n", | |
"x x x x o x x o 1\n", | |
"x x x x x x x o 2\n", | |
"x x x x x x x o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x o o 5\n", | |
"o o o o x x - o 6\n", | |
"o o o o o o x - 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x x o 0\n", | |
"x x x x o x x o 1\n", | |
"x x x x x x x o 2\n", | |
"x x x x x x x o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x o o 5\n", | |
"o o o o x x - o 6\n", | |
"o o o o o o o o 7\n", | |
"\n", | |
"\n", | |
"0 1 2 3 4 5 6 7\n", | |
"x x x x x x x o 0\n", | |
"x x x x o x x o 1\n", | |
"x x x x x x x o 2\n", | |
"x x x x x x x o 3\n", | |
"x x x o x x x o 4\n", | |
"x x x x x x x o 5\n", | |
"o o o o x x x o 6\n", | |
"o o o o o o o o 7\n", | |
"\n", | |
"\n", | |
"score -22\n" | |
] | |
} | |
], | |
"source": [ | |
"reward, log = run(qnet, agent, train=False, epsilon=0)\n", | |
"for c, step in enumerate(log):\n", | |
" localenv = Environ(board=step)\n", | |
" localenv.show_board()\n", | |
" print ''\n", | |
"print \"score %d\" % reward" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"念のため、完全にランダムに打つ相手との対戦結果を確認します。この結果では、勝率は75%になりました。興味深いことに、単純なミニマックス法に対する場合より勝率が下がりました。単純なミニマックス法を相手に学習したため、ランダムな手で生成される盤面に対する評価関数が適切に学習されていない可能性が考えられます。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def run_random(qnet, agent, train, epsilon):\n", | |
" env = Environ()\n", | |
" log = []\n", | |
" episode_record = []\n", | |
"\n", | |
" while (True):\n", | |
" # AI Player\n", | |
" episode_record.append(copy.deepcopy(env.board))\n", | |
" action = agent.get_action(qnet, env, player=1, train=train, epsilon=epsilon)\n", | |
" if action:\n", | |
" env.update_state(action, player=1)\n", | |
" log.append(copy.deepcopy(env.board))\n", | |
" \n", | |
" # Human Player\n", | |
" action = agent.get_action(qnet, env, player=-1, train=True, epsilon=1.0)\n", | |
" if action:\n", | |
" env.update_state(action, player=-1)\n", | |
" log.append(copy.deepcopy(env.board))\n", | |
" reward = env.get_reward(player=1)\n", | |
" if reward != 0:\n", | |
" episode_record.append(reward) # Episode ends with the reward value.\n", | |
" break\n", | |
" \n", | |
" return reward, log" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+ + + + - - - - - + - + + + - + + + + + + + + + + + + + + + + - - + + - + + + + - + + - + + + + + + + - + + + + - + - - + + + + + - + - + - - + + + - + + + + + + + + + + - + + - + + - + + + + + + - + \n", | |
"Average score 20.8, Winning rate 75%\n" | |
] | |
} | |
], | |
"source": [ | |
"n = 100\n", | |
"score, winrate = 0.0, 0.0\n", | |
"for _ in range(n):\n", | |
" reward, log = run_random(qnet, agent, train=False, epsilon=0)\n", | |
" if reward > 0:\n", | |
" winrate += 1.0\n", | |
" score += reward\n", | |
" print '+',\n", | |
" else:\n", | |
" print '-',\n", | |
"print ''\n", | |
"print 'Average score %.1f, Winning rate %d%%' % (score/n, winrate*100/n)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 演習課題" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
" 1. このモデルには多数のチューニング可能なパラメーターが含まれています。\n", | |
" \n", | |
" - ニューラルネットワークの構造\n", | |
" - 評価値を帰納的に定義する際のγの値\n", | |
" - 学習処理を実施する際のバッチサイズ、εの減少率\n", | |
" - 対戦相手に設定するεの値\n", | |
" - その他\n", | |
"\n", | |
" これらのパラメーターを変化させることで、性能をさらに向上させてください。\n", | |
"\n", | |
" 2. 次の点について、実験と考察を行ってください。\n", | |
" \n", | |
" - 複数の対戦相手(単純なミニマックス法+ランダムに打つ相手)を混ぜて学習することで学習結果はどのように変化するでしょうか?\n", | |
" - 打ち手の傾向を観察すると「角」を取ると有利になるという事実を完全には学習していないようです。このような定石を効率的に学習させることはできるでしょうか?\n", | |
" - その他" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Comment: 「シュミレーション」 is a misspelling; the correct form is 「シミュレーション」 (simulation).

Comment: The discussion assumes that a player using the 'my stones - opponent stones' evaluation function is stronger than a random player, but that is not self-evident. In fact, in Reversi the opening is often considered better played with fewer of your own stones, because that leaves more squares you can move to in the middle and end game, so for opening play a random player's moves are considerably better than those of the 'my stones - opponent stones' player. Unless the win rate between the random player and the 'my stones - opponent stones' player is verified first, it cannot be said whether the observation quoted above is really 'interesting'.