Created
June 23, 2016 12:43
Revisions
-
notsimon created this gist
Jun 23, 2016 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,115 @@ local torch =require 'torch' local nn =require 'nn' local rnn =require 'rnn' local gnuplot = require 'gnuplot' torch.setnumthreads(4) print('number of threads: ' .. torch.getnumthreads()) batchSize = 16 rho = 16 -- sequence length lr = 0.01 -- -- Model -- inputSize = 1 hiddenSize = 32 outputSize = 1 local r = nn.Recurrent( hiddenSize, -- start nn.Linear(inputSize, hiddenSize), -- input nn.Linear(hiddenSize, hiddenSize), -- feedback nn.Sigmoid(), -- transfer rho ) local rnn = nn.Sequential() :add(r) :add(nn.Linear(hiddenSize, outputSize)) rnn = nn.Sequencer(rnn) -- load a model previously trained --rnn = torch.load('sine-waves-model.dat', 'ascii', true) print(rnn) criterion = nn.SequencerCriterion(nn.MSECriterion()) -- -- Dataset -- local numSamples = 1024 local numPeriods = 10 local t = torch.linspace(0, numPeriods * 2 * math.pi, numSamples) local input = torch.Tensor(numSamples, inputSize) local output = torch.Tensor(numSamples, outputSize) input:select(2, 1):copy(torch.sin(t)) --input:select(2, 2):copy(torch.sin(t/2)) output:select(2, 1):copy(torch.sin(t/2)) --output:select(2, 2):copy(torch.sin(t*2)) -- -- Training -- local it = 1 while true do offsets = torch.LongTensor(batchSize) for i=1,batchSize do offsets[i] = math.ceil(math.random()*input:size(1)) end for a = 1, 2000 do -- create a batch of sequences of rho time-steps local x, y = {}, {} for step = 1, rho do x[step] = input:index(1, offsets) y[step] = output:index(1, offsets) -- incement indices offsets = offsets + 1 for j = 1, batchSize do if offsets[j] > numSamples then offsets[j] = 1 end end end -- forward the sequence local z = rnn:forward(x) local err = criterion:forward(z, y) print(string.format("[%d] err = %f", it, err / rho)) -- backward the sequence (i.e. BPTT) in reverse order of forward calls rnn:zeroGradParameters() local gz = criterion:backward(z, y) rnn:backward(x, gz) -- update rnn:updateParameters(lr) it = it + 1 end -- save the model print("Saving...") torch.save('sine-waves-model.dat', rnn, 'ascii', true) -- test on the full sequence local z = rnn:forward(input) gnuplot.pngfigure('sine-waves-test.png') gnuplot.plot( {'input', t, input:select(2, 1), '-'}, {'truth', t, output:select(2, 1), '-'}, {'estimate', t, z:select(2, 1), '-'} ) gnuplot.plotflush() end