;;;; REINFORCE on CartPole with the th library (gist by @chunsj, 2020-08-13)
(defpackage :gdrl-ch11
  (:use #:common-lisp
        #:mu
        #:th
        #:th.layers
        #:th.env
        #:th.env.cartpole))
(in-package :gdrl-ch11)
(defun train-env (&optional (max-steps 300))
  (cartpole-env :easy :reward max-steps))
(defun eval-env () (cartpole-env :eval))
(defun clamp-probs (probs)
  "clamp probabilities away from 0 and 1 so their log stays finite."
  ($clamp probs
          single-float-epsilon
          (- 1 single-float-epsilon)))
(defun $logPs (probs) ($log (clamp-probs probs)))
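;; A quick REPL check (illustrative, not part of the original gist):
;;   ($logPs (tensor '(0 1)))
;; returns finite values near (log single-float-epsilon) and 0, where a raw
;; ($log ...) would yield -inf for the zero entry.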
;;
;; REINFORCE
;;
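;; The per-step loss built in reinforce below is the standard REINFORCE
;; estimator
;;   loss_t = - gamma^t * G_t * log pi(a_t | s_t)
;; where G_t is the discounted return from step t onward and log pi(a_t | s_t)
;; is the log-probability node returned by select-action; minimizing it ascends
;; the policy-gradient objective.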
(defun model (&optional (ni 4) (no 2))
  (let ((h 8))
    (sequential-layer
     (affine-layer ni h :weight-initializer :random-uniform
                   :activation :relu)
     (affine-layer h no :weight-initializer :random-uniform
                   :activation :softmax))))
(defun policy (m state &optional (trainp T))
  (let ((s (if (eql ($ndim state) 1)
               ($unsqueeze state 0)
               state)))
    ($execute m s :trainp trainp)))
(defun select-action (m state &optional (trainp T))
  (let* ((probs (policy m state trainp))
         (logPs ($logPs probs))
         (ps (if ($parameterp probs) ($data probs) probs))
         (entropy ($- ($dot ps logPs)))
         (action ($multinomial ps 1))     ; sample an action from the policy
         (logP ($gather logPs 1 action))) ; log-probability of the sampled action
    (list ($scalar action) logP entropy)))
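;; select-action returns (action-index log-probability-node entropy); the
;; log-probability node is what the episode loop stores so the loss can
;; backpropagate through the policy network, while the entropy term is
;; returned but not used by reinforce below.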
(defun action-selector (m)
  "greedy action selection for evaluation."
  (lambda (state)
    (let ((probs (policy m state nil)))
      ($scalar ($argmax probs 1)))))
(defun reinforce (m &optional (max-episodes 4000))
  "REINFORCE, updating once per episode."
  (let* ((gamma 0.99)
         (lr 0.01)
         (env (train-env))
         (avg-score nil)
         (success nil))
    (loop :while (not success)
          :repeat max-episodes
          :for e :from 1
          :for state = (env/reset! env)
          :for rewards = '()
          :for logPs = '()
          :for score = 0
          :for done = nil
          :do (let ((losses nil))
                ;; roll out one episode, recording log-probabilities and rewards
                (loop :while (not done)
                      :for (action logP entropy) = (select-action m state)
                      :for (next-state reward terminalp) = (cdr (env/step! env action))
                      :do (progn
                            (push logP logPs)
                            (push reward rewards)
                            (incf score reward)
                            (setf state next-state
                                  done terminalp)))
                (setf logPs (reverse logPs))
                ;; rewards here calls the discounted-return helper
                ;; (presumably from th.env; it is not defined in this gist).
                (setf rewards (rewards (reverse rewards) gamma T))
                ;; per-step loss: - gamma^t * G_t * log pi(a_t | s_t)
                (loop :for logP :in logPs
                      :for vt :in rewards
                      :for i :from 0
                      :for gm = (expt gamma i)
                      :for l = ($- ($* gm logP vt))
                      ;; in practice, we do not have to collect the losses;
                      ;; each loss has an independent computational graph.
                      :do (push l losses))
                ($amgd! m lr)
                (if (null avg-score)
                    (setf avg-score score)
                    (setf avg-score (+ (* 0.9 avg-score) (* 0.1 score))))
                (when (zerop (rem e 100))
                  (let ((escore (cadr (evaluate (eval-env) (action-selector m)))))
                    (when (and (>= avg-score (* 0.9 300)) (>= escore 3000))
                      (setf success T))
                    (prn (format nil "~5D: ~8,2F / ~5,0F" e avg-score escore))))))
    avg-score))
;; train with REINFORCE
(defparameter *m* (model))
(reinforce *m* 4000)
;; evaluation
(evaluate (eval-env) (action-selector *m*))
;;
;; REINFORCE - batch updating
;;
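;; Batch variant: trace-episode rolls out nbatch episodes with the current
;; policy, stacks the visited states, sampled actions, discounted returns and
;; gamma^t factors into tensors, and compute-loss recomputes the
;; log-probabilities in one forward pass before a single $amgd! update.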
;; redefined for batch updating: only the sampled action index is needed here,
;; since gradients are recomputed from the stored trajectories in compute-loss.
(defun select-action (m state)
  ($scalar ($multinomial (policy m state nil) 1)))
(defun trace-episode (env m gamma &optional (nb 1))
  "collect nb episode trajectories with the given policy model."
  (let ((states nil)
        (actions nil)
        (rewards nil)
        (gammas nil)
        (done nil)
        (score 0)
        (state nil))
    (loop :repeat nb
          :do (progn
                (setf state (env/reset! env))
                (loop :while (not done)
                      :for action = (select-action m state)
                      :for (_ next-state reward terminalp) = (env/step! env action)
                      :for i :from 0
                      :do (progn
                            (push ($list state) states)
                            (push action actions)
                            (push reward rewards)
                            (push (expt gamma i) gammas)
                            (incf score reward)
                            (setf state next-state
                                  done terminalp)))
                (setf done nil)))
    ;; note: the returns below are computed over the concatenated reward list
    ;; of all nb episodes.
    (let ((n ($count states)))
      (list (tensor (reverse states))
            (-> (tensor.long (reverse actions))
                ($reshape! n 1))
            (-> (rewards (reverse rewards) gamma T)
                (tensor)
                ($reshape! n 1))
            (-> (tensor (reverse gammas))
                ($reshape! n 1))
            (/ score nb)))))
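;; trace-episode returns (states actions returns gammas mean-score):
;;   states     - n x 4 tensor of cartpole observations across the nb episodes
;;   actions    - n x 1 long tensor of the sampled action indices
;;   returns    - n x 1 tensor of discounted returns from the rewards helper
;;   gammas     - n x 1 tensor of the per-step gamma^t factors
;;   mean-score - average undiscounted score per episode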
(defun compute-loss (m states actions rewards gammas)
  (let ((logPs ($gather ($logPs (policy m states)) 1 actions)))
    ($mean ($* -1 gammas rewards logPs))))
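;; compute-loss is the batched form of the same REINFORCE estimator:
;;   loss = mean_t [ - gamma^t * G_t * log pi(a_t | s_t) ]
;; where the log-probabilities are gathered only for the actions actually taken.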
(defun reinforce (m &optional (nbatch 5) (max-episodes 4000))
  "REINFORCE with batch updating."
  (let* ((gamma 0.99)
         (lr 0.04)
         (env (train-env))
         (avg-score nil)
         (success nil))
    (loop :while (not success)
          :repeat (round (/ max-episodes nbatch))
          :for e :from 1
          :do (let* ((res (trace-episode env m gamma nbatch))
                     (states ($0 res))
                     (actions ($1 res))
                     (rewards ($2 res))
                     (gammas ($3 res))
                     (score ($4 res))
                     (loss (compute-loss m states actions rewards gammas)))
                ($amgd! m lr)
                (if (null avg-score)
                    (setf avg-score score)
                    (setf avg-score (+ (* 0.9 avg-score) (* 0.1 score))))
                (when (zerop (rem e 100))
                  (let ((escore (cadr (evaluate (eval-env) (action-selector m)))))
                    (when (and (>= avg-score (* 0.9 300)) (>= escore 3000))
                      (setf success T))
                    (prn (format nil "~5D: ~8,2F / ~5,0F ~12,4F" e avg-score escore
                                 ($scalar ($data loss))))))))
    avg-score))
(defparameter *m* (model))
(reinforce *m* 10 4000)
(evaluate (eval-env) (action-selector *m*))