Skip to content

Instantly share code, notes, and snippets.

@simgt
Created June 23, 2016 12:43

Revisions

  1. @notsimon notsimon created this gist Jun 23, 2016.
    115 changes: 115 additions & 0 deletions rnn-sine-waves.lua
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,115 @@
    local torch =require 'torch'
    local nn =require 'nn'
    local rnn =require 'rnn'
    local gnuplot = require 'gnuplot'

    torch.setnumthreads(4)
    print('number of threads: ' .. torch.getnumthreads())

    batchSize = 16
    rho = 16 -- sequence length
    lr = 0.01

    --
    -- Model
    --

    inputSize = 1
    hiddenSize = 32
    outputSize = 1

    local r = nn.Recurrent(
    hiddenSize, -- start
    nn.Linear(inputSize, hiddenSize), -- input
    nn.Linear(hiddenSize, hiddenSize), -- feedback
    nn.Sigmoid(), -- transfer
    rho
    )

    local rnn = nn.Sequential()
    :add(r)
    :add(nn.Linear(hiddenSize, outputSize))

    rnn = nn.Sequencer(rnn)

    -- load a model previously trained
    --rnn = torch.load('sine-waves-model.dat', 'ascii', true)

    print(rnn)

    criterion = nn.SequencerCriterion(nn.MSECriterion())

    --
    -- Dataset
    --

    local numSamples = 1024
    local numPeriods = 10

    local t = torch.linspace(0, numPeriods * 2 * math.pi, numSamples)
    local input = torch.Tensor(numSamples, inputSize)
    local output = torch.Tensor(numSamples, outputSize)
    input:select(2, 1):copy(torch.sin(t))
    --input:select(2, 2):copy(torch.sin(t/2))
    output:select(2, 1):copy(torch.sin(t/2))
    --output:select(2, 2):copy(torch.sin(t*2))

    --
    -- Training
    --

    local it = 1
    while true do
    offsets = torch.LongTensor(batchSize)
    for i=1,batchSize do
    offsets[i] = math.ceil(math.random()*input:size(1))
    end

    for a = 1, 2000 do
    -- create a batch of sequences of rho time-steps
    local x, y = {}, {}
    for step = 1, rho do
    x[step] = input:index(1, offsets)
    y[step] = output:index(1, offsets)

    -- incement indices
    offsets = offsets + 1
    for j = 1, batchSize do
    if offsets[j] > numSamples then
    offsets[j] = 1
    end
    end
    end

    -- forward the sequence
    local z = rnn:forward(x)
    local err = criterion:forward(z, y)

    print(string.format("[%d] err = %f", it, err / rho))

    -- backward the sequence (i.e. BPTT) in reverse order of forward calls
    rnn:zeroGradParameters()
    local gz = criterion:backward(z, y)
    rnn:backward(x, gz)

    -- update
    rnn:updateParameters(lr)

    it = it + 1
    end

    -- save the model
    print("Saving...")
    torch.save('sine-waves-model.dat', rnn, 'ascii', true)

    -- test on the full sequence
    local z = rnn:forward(input)

    gnuplot.pngfigure('sine-waves-test.png')
    gnuplot.plot(
    {'input', t, input:select(2, 1), '-'},
    {'truth', t, output:select(2, 1), '-'},
    {'estimate', t, z:select(2, 1), '-'}
    )
    gnuplot.plotflush()
    end