(defpackage :gdrl-ch11
  (:use #:common-lisp
        #:mu
        #:th
        #:th.layers
        #:th.env
        #:th.env.cartpole))

(in-package :gdrl-ch11)
(defun train-env (&optional (max-steps 300)) (cartpole-env :easy :reward max-steps))
(defun eval-env () (cartpole-env :eval))

;; keep probabilities strictly inside (0, 1) so that $log never produces -inf
(defun clamp-probs (probs)
  ($clamp probs
          single-float-epsilon
          (- 1 single-float-epsilon)))

(defun $logPs (probs) ($log (clamp-probs probs)))
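;; A minimal sketch of the helpers above (commented out; the literal
;; probabilities are only an illustrative assumption): clamping keeps the log
;; finite even when a probability is exactly 0 or 1.
;; (prn ($logPs (tensor '(0.0 0.5 1.0))))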
;;
;; REINFORCE
;;

;; two-layer policy network: NI state features -> 8 hidden units (relu)
;; -> NO action probabilities (softmax)
(defun model (&optional (ni 4) (no 2))
  (let ((h 8))
    (sequential-layer
     (affine-layer ni h :weight-initializer :random-uniform
                   :activation :relu)
     (affine-layer h no :weight-initializer :random-uniform
                   :activation :softmax))))
(defun policy (m state &optional (trainp T))
  "Run the policy network; a single state is treated as a batch of one."
  (let ((s (if (eql ($ndim state) 1)
               ($unsqueeze state 0)
               state)))
    ($execute m s :trainp trainp)))

(defun select-action (m state &optional (trainp T))
  "Sample an action from the policy; returns (action log-probability entropy)."
  (let* ((probs (policy m state trainp))
         (logPs ($logPs probs))
         (ps (if ($parameterp probs) ($data probs) probs))
         (entropy ($- ($dot ps logPs)))
         (action ($multinomial ps 1))
         (logP ($gather logPs 1 action)))
    (list ($scalar action) logP entropy)))
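;; A minimal sketch (commented out; uses a throwaway model and environment):
;; select-action returns the sampled action index, the log-probability node of
;; that action, and the entropy of the action distribution.
;; (prn (select-action (model) (env/reset! (train-env))))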
(defun action-selector (m)
  "Greedy (argmax) action selection, used for evaluation."
  (lambda (state)
    (let ((probs (policy m state nil)))
      ($scalar ($argmax probs 1)))))
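;; Per-step REINFORCE objective used below: for step t with discounted return
;; Gt and sampled action At in state St, the loss is
;; -(gamma^t * Gt * log pi(At|St)), so gradient ascent on the expected return
;; becomes gradient descent on this loss.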
(defun reinforce (m &optional (max-episodes 4000))
  "REINFORCE, updating the model once per episode."
  (let* ((gamma 0.99)
         (lr 0.01)
         (env (train-env))
         (avg-score nil)
         (success nil))
    (loop :while (not success)
          :repeat max-episodes
          :for e :from 1
          :for state = (env/reset! env)
          :for rewards = '()
          :for logPs = '()
          :for score = 0
          :for done = nil
          :do (let ((losses nil))
                ;; roll out one episode, recording rewards and log-probabilities
                (loop :while (not done)
                      :for (action logP entropy) = (select-action m state)
                      :for (next-state reward terminalp) = (cdr (env/step! env action))
                      :do (progn
                            (push logP logPs)
                            (push reward rewards)
                            (incf score reward)
                            (setf state next-state
                                  done terminalp)))
                (setf logPs (reverse logPs))
                ;; convert the reward sequence into discounted returns
                (setf rewards (rewards (reverse rewards) gamma T))
                (loop :for logP :in logPs
                      :for vt :in rewards
                      :for i :from 0
                      :for gm = (expt gamma i)
                      :for l = ($- ($* gm logP vt))
                      ;; collecting losses is not strictly required;
                      ;; each loss node has its own computational graph.
                      :do (push l losses))
                ($amgd! m lr)
                (if (null avg-score)
                    (setf avg-score score)
                    (setf avg-score (+ (* 0.9 avg-score) (* 0.1 score))))
                (when (zerop (rem e 100))
                  (let ((escore (cadr (evaluate (eval-env) (action-selector m)))))
                    (when (and (>= avg-score (* 0.9 300)) (>= escore 3000))
                      (setf success T))
                    (prn (format nil "~5D: ~8,2F / ~5,0F" e avg-score escore))))))
    avg-score))
;; train with REINFORCE
(defparameter *m* (model))
(reinforce *m* 4000)

;; evaluation
(evaluate (eval-env) (action-selector *m*))
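;; A minimal sketch (commented out; assumes *m* has been trained above):
;; inspect the action probabilities the policy assigns to a fresh initial state.
;; (prn (policy *m* (env/reset! (eval-env)) nil))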
;;
;; REINFORCE - batch updating
;;

;; for batch updating, action selection only needs the sampled action index
(defun select-action (m state) ($scalar ($multinomial (policy m state nil) 1)))

(defun trace-episode (env m gamma &optional (nb 1))
  "Collect NB episode trajectories with the given policy model; returns
(states actions returns gammas average-score)."
  (let ((states nil)
        (actions nil)
        (rewards nil)
        (gammas nil)
        (done nil)
        (score 0)
        (state nil))
    (loop :repeat nb
          :do (progn
                (setf state (env/reset! env))
                (loop :while (not done)
                      :for action = (select-action m state)
                      :for (_ next-state reward terminalp) = (env/step! env action)
                      :for i :from 0
                      :do (progn
                            (push ($list state) states)
                            (push action actions)
                            (push reward rewards)
                            (push (expt gamma i) gammas)
                            (incf score reward)
                            (setf state next-state
                                  done terminalp)))
                (setf done nil)))
    (let ((n ($count states)))
      (list (tensor (reverse states))
            (-> (tensor.long (reverse actions))
                ($reshape! n 1))
            (-> (rewards (reverse rewards) gamma T)
                (tensor)
                ($reshape! n 1))
            (-> (tensor (reverse gammas))
                ($reshape! n 1))
            (/ score nb)))))
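;; A minimal sketch (commented out; the gamma value and batch size are
;; illustrative assumptions): trace two episodes with an untrained model and
;; print the average score; the full result list is what the batch version of
;; reinforce below destructures.
;; (prn ($4 (trace-episode (train-env) (model) 0.99 2)))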
(defun compute-loss (m states actions rewards gammas)
  (let ((logPs ($gather ($logPs (policy m states)) 1 actions)))
    ($mean ($* -1 gammas rewards logPs))))
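;; The batch loss above is the mean over all collected steps of
;; -(gamma^t * Gt * log pi(At|St)), i.e. the same per-step REINFORCE loss as in
;; the episodic version, averaged over the NB traced episodes before a single
;; parameter update.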
(defun reinforce (m &optional (nbatch 5) (max-episodes 4000))
  "REINFORCE with batch updating: one gradient step per NBATCH traced episodes."
  (let* ((gamma 0.99)
         (lr 0.04)
         (env (train-env))
         (avg-score nil)
         (success nil))
    (loop :while (not success)
          :repeat (round (/ max-episodes nbatch))
          :for e :from 1
          :do (let* ((res (trace-episode env m gamma nbatch))
                     (states ($0 res))
                     (actions ($1 res))
                     (rewards ($2 res))
                     (gammas ($3 res))
                     (score ($4 res))
                     (loss (compute-loss m states actions rewards gammas)))
                ($amgd! m lr)
                (if (null avg-score)
                    (setf avg-score score)
                    (setf avg-score (+ (* 0.9 avg-score) (* 0.1 score))))
                (when (zerop (rem e 100))
                  (let ((escore (cadr (evaluate (eval-env) (action-selector m)))))
                    (when (and (>= avg-score (* 0.9 300)) (>= escore 3000))
                      (setf success T))
                    (prn (format nil "~5D: ~8,2F / ~5,0F ~12,4F" e avg-score escore
                                 ($scalar ($data loss))))))))
    avg-score))
(defparameter *m* (model))
(reinforce *m* 10 4000)
(evaluate (eval-env) (action-selector *m*))