Skip to content

Instantly share code, notes, and snippets.

@chunsj
Created July 17, 2019 04:09
Show Gist options
  • Select an option

  • Save chunsj/390c4e1630cbc50ea75ad27c5c6a6b01 to your computer and use it in GitHub Desktop.

Select an option

Save chunsj/390c4e1630cbc50ea75ad27c5c6a6b01 to your computer and use it in GitHub Desktop.
Obama speech generation
;; from
;; http://karpathy.github.io/2015/05/21/rnn-effectiveness/
;; Package for the character-level LSTM text generation example.
;; Besides COMMON-LISP it uses the mu, th and th.ex.data packages.
(defpackage :genchars-obama-lstm
(:use #:common-lisp
#:mu
#:th
#:th.ex.data))
(in-package :genchars-obama-lstm)
;; Runtime tuning through internal (double-colon) symbols of the th package.
;; NOTE(review): presumably the worker thread count and a GC ceiling of
;; 8 * 1024^3 (bytes?) -- confirm the units against th.
(th::th-set-num-threads 12)
(th::th-set-gc-hard-max (* 8 1024 1024 1024))
;; Corpus loading and vocabulary construction.

;; Lines of the :obama text; lines whose $count is below 1 (empty) are dropped.
(defparameter *data-lines*
  (remove-if (lambda (text) (< ($count text) 1))
             (text-lines :obama)))

;; The whole corpus as one string, lines joined with newlines.
(defparameter *data* (format nil "~{~A~^~%~}" *data-lines*))

;; Distinct characters of the corpus; a character's position in this list
;; is its index.
(defparameter *chars* (remove-duplicates (coerce *data* 'list)))

(defparameter *data-size* ($count *data*))
(defparameter *vocab-size* ($count *chars*))

;; character -> index lookup table
(defparameter *char-to-idx*
  (let ((table #{})
        (index 0))
    (dolist (ch *chars* table)
      (setf ($ table ch) index)
      (incf index))))

;; index -> character lookup (the character list itself)
(defparameter *idx-to-char* *chars*)
(defun choose (probs)
  "Draw one index at random from PROBS, treating its entries as weights.
PROBS is divided by its own sum first, so it does not have to be normalized."
  (let* ((total ($sum probs))
         (normalized ($div probs total))
         (draw ($multinomial normalized 1)))
    ;; The draw is reshaped before its first element is read out.
    ($ ($reshape! draw ($count normalized)) 0)))
;;
;; Non-batched LSTM, for example purposes.
;;
;; Every weight below is registered in *lstm* through $push, and *lstm* is what
;; the training code hands to the optimizer.
;; Except for *bf*, each weight is initialized as 0.16 * rnd - 0.08.
;; NOTE(review): assuming rnd yields uniform values in [0, 1), this is a
;; uniform draw from [-0.08, 0.08) -- confirm.
;; NOTE(review): suffixes i / f / o presumably name the input, forget and output
;; gates and a the cell candidate -- confirm against the $lstm argument order.
(defparameter *hidden-size* 100)
(defparameter *sequence-length* 50)
(defparameter *lstm* (parameters))
;; "a" weights: W* is vocab x hidden, U* is hidden x hidden, b* has hidden entries.
(defparameter *wa* ($push *lstm* ($- ($* 0.16 (rnd *vocab-size* *hidden-size*)) 0.08)))
(defparameter *ua* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size* *hidden-size*)) 0.08)))
(defparameter *ba* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size*)) 0.08)))
;; "i" weights
(defparameter *wi* ($push *lstm* ($- ($* 0.16 (rnd *vocab-size* *hidden-size*)) 0.08)))
(defparameter *ui* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size* *hidden-size*)) 0.08)))
(defparameter *bi* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size*)) 0.08)))
;; "f" weights; note the bias starts at all ones instead of a random draw
(defparameter *wf* ($push *lstm* ($- ($* 0.16 (rnd *vocab-size* *hidden-size*)) 0.08)))
(defparameter *uf* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size* *hidden-size*)) 0.08)))
(defparameter *bf* ($push *lstm* (ones *hidden-size*)))
;; "o" weights
(defparameter *wo* ($push *lstm* ($- ($* 0.16 (rnd *vocab-size* *hidden-size*)) 0.08)))
(defparameter *uo* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size* *hidden-size*)) 0.08)))
(defparameter *bo* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size*)) 0.08)))
;; Output projection from the hidden state to vocabulary scores (hidden x vocab).
(defparameter *wy* ($push *lstm* ($- ($* 0.16 (rnd *hidden-size* *vocab-size*)) 0.08)))
(defparameter *by* ($push *lstm* ($- ($* 0.16 (rnd *vocab-size*)) 0.08)))
(defun lstm-write-weight-to (w fname)
  "Write the data of parameter W (its $data) to the file FNAME.
The file is opened with mode \"w\" and is closed even when the write signals
an error, so a failed write no longer leaks the file handle.
Returns the value of $fclose, as before."
  (let ((f (file.disk fname "w"))
        (closed nil))
    (unwind-protect
         (progn
           ($fwrite ($data w) f)
           ;; Mark before closing so the cleanup form never closes twice.
           (setf closed t)
           ($fclose f))
      (unless closed
        ($fclose f)))))
(defun lstm-read-weight-from (w fname)
  "Read the file FNAME into the data of parameter W (its $data), in place.
The file is opened with mode \"r\" and is closed even when the read signals
an error, so a failed read no longer leaks the file handle.
Returns the value of $fclose, as before."
  (let ((f (file.disk fname "r"))
        (closed nil))
    (unwind-protect
         (progn
           ($fread ($data w) f)
           ;; Mark before closing so the cleanup form never closes twice.
           (setf closed t)
           ($fclose f))
      (unless closed
        ($fclose f)))))
(defun lstm-write-weights ()
  "Save every LSTM parameter to its own file under
examples/weights/genchar-obama-lstm/, in a fixed order.
Returns whatever the last individual write returned."
  (let ((last-result nil))
    (loop :for (weight path)
            :in (list (list *wa* "examples/weights/genchar-obama-lstm/lstm-wa.dat")
                      (list *ua* "examples/weights/genchar-obama-lstm/lstm-ua.dat")
                      (list *ba* "examples/weights/genchar-obama-lstm/lstm-ba.dat")
                      (list *wi* "examples/weights/genchar-obama-lstm/lstm-wi.dat")
                      (list *ui* "examples/weights/genchar-obama-lstm/lstm-ui.dat")
                      (list *bi* "examples/weights/genchar-obama-lstm/lstm-bi.dat")
                      (list *wf* "examples/weights/genchar-obama-lstm/lstm-wf.dat")
                      (list *uf* "examples/weights/genchar-obama-lstm/lstm-uf.dat")
                      (list *bf* "examples/weights/genchar-obama-lstm/lstm-bf.dat")
                      (list *wo* "examples/weights/genchar-obama-lstm/lstm-wo.dat")
                      (list *uo* "examples/weights/genchar-obama-lstm/lstm-uo.dat")
                      (list *bo* "examples/weights/genchar-obama-lstm/lstm-bo.dat")
                      (list *wy* "examples/weights/genchar-obama-lstm/lstm-wy.dat")
                      (list *by* "examples/weights/genchar-obama-lstm/lstm-by.dat"))
          :do (setf last-result (lstm-write-weight-to weight path)))
    last-result))
(defun lstm-read-weights ()
  "Load every LSTM parameter from its own file under
examples/weights/genchar-obama-lstm/, in a fixed order.
Returns whatever the last individual read returned."
  (let ((last-result nil))
    (loop :for (weight path)
            :in (list (list *wa* "examples/weights/genchar-obama-lstm/lstm-wa.dat")
                      (list *ua* "examples/weights/genchar-obama-lstm/lstm-ua.dat")
                      (list *ba* "examples/weights/genchar-obama-lstm/lstm-ba.dat")
                      (list *wi* "examples/weights/genchar-obama-lstm/lstm-wi.dat")
                      (list *ui* "examples/weights/genchar-obama-lstm/lstm-ui.dat")
                      (list *bi* "examples/weights/genchar-obama-lstm/lstm-bi.dat")
                      (list *wf* "examples/weights/genchar-obama-lstm/lstm-wf.dat")
                      (list *uf* "examples/weights/genchar-obama-lstm/lstm-uf.dat")
                      (list *bf* "examples/weights/genchar-obama-lstm/lstm-bf.dat")
                      (list *wo* "examples/weights/genchar-obama-lstm/lstm-wo.dat")
                      (list *uo* "examples/weights/genchar-obama-lstm/lstm-uo.dat")
                      (list *bo* "examples/weights/genchar-obama-lstm/lstm-bo.dat")
                      (list *wy* "examples/weights/genchar-obama-lstm/lstm-wy.dat")
                      (list *by* "examples/weights/genchar-obama-lstm/lstm-by.dat"))
          :do (setf last-result (lstm-read-weight-from weight path)))
    last-result))
(defun cindices (str)
  "One-hot encode STR: returns a zero-initialized matrix with one row per
character of STR and *vocab-size* columns, where row i has a 1 in the column
given by *char-to-idx* for the i-th character."
  (let* ((len ($count str))
         (onehot (zeros len *vocab-size*)))
    (dotimes (row len onehot)
      (setf ($ onehot row ($ *char-to-idx* ($ str row))) 1))))
(defun rstrings (indices)
  "Turn the list INDICES of character indices back into a string by looking
each index up in *idx-to-char*."
  (coerce (loop :for index :in indices
                :collect ($ *idx-to-char* index))
          'string))
(defun seedh (str &optional (temperature 1))
  "Feed the seed string STR through the LSTM one character at a time, starting
from zero hidden and cell states.  At every step a next-character index is
sampled from the softmax of the output scores divided by TEMPERATURE.
Returns a list (INDEX HIDDEN CELL): the index sampled at the last step (0 when
STR is empty) and the final hidden and cell states."
  (let ((onehots (cindices str))
        (hidden (zeros 1 *hidden-size*))
        (cell (zeros 1 *hidden-size*))
        ;; work on the $data of each parameter rather than the parameter itself
        (wa ($data *wa*))
        (ua ($data *ua*))
        (ba ($data *ba*))
        (wi ($data *wi*))
        (ui ($data *ui*))
        (bi ($data *bi*))
        (wf ($data *wf*))
        (uf ($data *uf*))
        (bf ($data *bf*))
        (wo ($data *wo*))
        (uo ($data *uo*))
        (bo ($data *bo*))
        (wy ($data *wy*))
        (by ($data *by*))
        (sampled 0))
    (dotimes (step ($size onehots 0))
      (let* ((state ($lstm ($index onehots 0 step) hidden cell
                           wi ui wf uf wo uo wa ua bi bf bo ba))
             (logits ($affine (first state) wy by))
             (probs ($softmax ($/ logits temperature))))
        ;; sample first, then advance the recurrent state
        (setf sampled (choose probs)
              hidden (first state)
              cell (second state))))
    (list sampled hidden cell)))
(defun sample (str n &optional (temperature 1))
  "Generate text with the LSTM.
When STR is non-nil it is used as a seed: SEEDH supplies the first generated
index and the starting hidden/cell states.  Otherwise the first index is
random and the states start at zero.  N further indices are then sampled, each
from the softmax of the output scores divided by TEMPERATURE, feeding every
sampled character back in as the next one-hot input.
Returns STR followed by the N + 1 generated characters, as a string."
  (let* ((x (zeros 1 *vocab-size*))
         (seed (when str (seedh str temperature)))
         ;; work on the $data of each parameter rather than the parameter itself
         (wa ($data *wa*))
         (ua ($data *ua*))
         (ba ($data *ba*))
         (wi ($data *wi*))
         (ui ($data *ui*))
         (bi ($data *bi*))
         (wf ($data *wf*))
         (uf ($data *uf*))
         (bf ($data *bf*))
         (wo ($data *wo*))
         (uo ($data *uo*))
         (bo ($data *bo*))
         (wy ($data *wy*))
         (by ($data *by*))
         (first-index (if seed ($0 seed) (random *vocab-size*)))
         (ph (if seed ($1 seed) (zeros 1 *hidden-size*)))
         (pc (if seed ($2 seed) (zeros 1 *hidden-size*)))
         ;; generated indices, most recent first
         (indices (list first-index)))
    (setf ($ x 0 first-index) 1)
    (loop :for step :from 0 :below n
          :do (let* ((state ($lstm x ph pc wi ui wf uf wo uo wa ua bi bf bo ba))
                     (logits ($affine (first state) wy by))
                     (next-index (choose ($softmax ($/ logits temperature)))))
                (setf ph (first state)
                      pc (second state))
                (push next-index indices)
                ;; rebuild the one-hot input for the character just sampled
                ($zero! x)
                (setf ($ x 0 next-index) 1)))
    (concatenate 'string str (rstrings (reverse indices)))))
;; Exclusive upper bound for chunk start offsets: leaves room for a full chunk
;; plus the extra character needed by the one-character-shifted target.
(defparameter *upto* (- *data-size* *sequence-length* 1))

;; XXX: a better strategy for building the data is needed;
;; for example, breaking the chunks at word boundaries would be better.

;; One one-hot matrix (*sequence-length* rows, *vocab-size* columns) for each
;; consecutive, non-overlapping chunk of the corpus.  The encoding is delegated
;; to CINDICES instead of repeating its loop here.
(defparameter *inputs*
  (loop :for p :from 0 :below *upto* :by *sequence-length*
        :collect (cindices (subseq *data* p (+ p *sequence-length*)))))
;; Training targets: for each chunk start offset p, the one-hot encoding of the
;; same-length chunk shifted one character to the right, so row i holds the
;; character that follows input row i.  The encoding is delegated to CINDICES
;; instead of repeating its loop here.
(defparameter *targets*
  (loop :for p :from 0 :below *upto* :by *sequence-length*
        :collect (cindices (subseq *data* (1+ p) (+ p *sequence-length* 1)))))
;; Smoothed per-chunk loss used by the training loop.  It starts at the loss of
;; a uniform guess: -log(1 / vocab-size) for each of the *sequence-length* steps.
(defparameter *mloss* (* (- (log (/ 1 *vocab-size*))) *sequence-length*))
;; Best smoothed loss so far; the training loop saves weights when a pass beats it.
(defparameter *min-mloss* *mloss*)
;; NOTE(review): presumably clears the gradients of the parameters registered in
;; *lstm* before training starts -- confirm against th.
($cg! *lstm*)
;; NOTE(review): presumably forces a garbage collection -- confirm.
(gcf)
;; Training: 50 passes over all chunks with one optimizer step per chunk.
;; The whole run is wrapped in TIME.
(time
(loop :for iter :from 1 :to 50
;; per-pass bookkeeping, reset at the start of every pass:
;; n = chunk counter, maxloss / maxloss-pos = worst chunk loss and its chunk
;; number, max-mloss = highest smoothed loss seen during the pass
:for n = 0
:for maxloss = 0
:for maxloss-pos = -1
:for max-mloss = 0
:do (progn
(loop :for input :in *inputs*
:for target :in *targets*
;; every chunk starts from zero hidden and cell states
:do (let ((ph (zeros 1 *hidden-size*))
(pc (zeros 1 *hidden-size*))
(tloss 0))
;; forward pass over the rows of the chunk, summing the $cee loss of each
;; step; unlike sampling, the parameters themselves are passed (not $data)
(loop :for i :from 0 :below ($size input 0)
:for xt = ($index input 0 i)
:for (ht ct) = ($lstm xt ph pc *wi* *ui* *wf* *uf* *wo* *uo* *wa* *ua*
*bi* *bf* *bo* *ba*)
:for yt = ($affine ht *wy* *by*)
:for ps = ($softmax yt)
:for y = ($index target 0 i)
:for l = ($cee ps y)
:do (progn
(setf ph ht
pc ct)
(incf tloss ($data l))))
;; remember the worst chunk of this pass
(when (> tloss maxloss)
(setf maxloss-pos n)
(setf maxloss tloss))
;; NOTE(review): presumably an RMSprop-style update of all parameters in
;; *lstm* -- confirm against th.
($rmgd! *lstm*)
;; exponential moving average of the chunk loss (weight 0.001 on the new value)
(setf *mloss* (+ (* 0.999 *mloss*) (* 0.001 tloss)))
(when (> *mloss* max-mloss) (setf max-mloss *mloss*))
;; progress report every 200 chunks
(when (zerop (rem n 200))
(prn "[ITER]" iter n *mloss* maxloss maxloss-pos))
(incf n)))
;; save the weights when even the highest smoothed loss of this pass is below
;; the best value recorded so far
(when (< max-mloss *min-mloss*)
(prn "*** BETTER MLOSS - WRITE WEIGHTS: FROM" *min-mloss* "TO" max-mloss)
(setf *min-mloss* max-mloss)
(lstm-write-weights)))))
;; Print two generated samples: 200 sampling steps each, temperature 0.5,
;; seeded with a sentence and with a single letter.
(prn (sample "This is not correct." 200 0.5))
(prn (sample "I" 200 0.5))
;; Save the current weights / load previously saved weights.
;; NOTE(review): these look like statements meant to be evaluated individually
;; at the REPL rather than run back to back -- confirm.
(lstm-write-weights)
(lstm-read-weights)
;; rmgd 0.002 0.99 - 1.31868 - 1.61637
;; adgd - 1.551497 - 1.841827
;; amgd 0.002 - 1.3747485 - 1.70623
;; Print the input text of a hand-picked set of chunk numbers, each prefixed
;; with its chunk number.
(let ((wanted '(75856 44515 44514 21663 18796 1258 336 178)))
  (loop :for start :from 0 :below *upto* :by *sequence-length*
        :for chunk-no :from 0
        :when (member chunk-no wanted)
          :do (prn (format nil "~6,d" chunk-no)
                   (subseq *data* start (+ start *sequence-length*)))))
I want to work for the strength of the time they should be a distance and face of the moment. That's what the months recovered in the people that he will be some cases have all of the groung the political to chance to lost the dreams. And if they don't do something -- the goods. It is a completed to lead to prevent on the courage in free the bravely after you had the debt to pass the more the promise of the prayer in Americans recognize the -- in Americans who are insurance same to be vote the families of Americans and all of the country again the truth. And when it is more than every decade of your longer with the months. And this country thing that we can stand for so it is a stravery and for the forces to pay the world and get in the or where we are what we can be some with the responsible to be able to make sure that everybody will be do on the world for what we will be the people who had forward that we have care the promise of the talking about all of the politics and a more of your politics at you who had a world very strong more than the notion to focused to make the American million jobs to honor of the debt to help for more than the story is an exploits to afford to the forces that believed to the American college to prove a family was a process in the company was company new company who will be secure and jobs to get the loved our concern the only more than minds to make now and the murder just as the things that we can be made on the states of the people -- they're her long and a thing they know what the decision in the good jobs and health care traming for years ago, they don't want to be person this was a decision and the companies of yourselves for it decide that what it would all this can opposed it where you serve the world with the world in a constant has the most promise. It was a constitution to law. And the next time. And that's what we had the get college.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment