Created
July 9, 2015 21:20
-
-
Save ndronen/f6ce80b7343a73c18072 to your computer and use it in GitHub Desktop.
Minimal working example of something that doesn't work with nn.Concat.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env th
require 'nn'

-- Minimal reproduction of an error raised by nn.Concat during the backward
-- pass. The script builds a tiny regression model whose first layer is an
-- nn.Concat wrapping a LookupTable + Sum pipeline, runs one forward pass and
-- then one backward pass. The backward pass is expected to fail with the
-- traceback quoted near the end of the script.
--
-- Usage:
--   th repro.lua             -- wrap the lookup pipeline in nn.Concat
--   th repro.lua -noConcat   -- use the lookup pipeline directly, for comparison

local cmd = torch.CmdLine()
cmd:text()
cmd:text("What's wrong with this use of nn.Concat?")
cmd:text('Options:')
cmd:option('-noConcat', false, 'do not include concat layer')
local opt = cmd:parse(arg or {})

-- Mock up a training example with regression target 3.
local target = torch.Tensor(1):fill(3)
-- Four lookup indices; the 1x4 literal is resized to a 1-D tensor of size 4.
local input = torch.Tensor({{ 1,2,4,1 }}):resize(4)

-- A lookup table layer, output size 6.
-- The table holds 5 rows of width 6; nn.Sum(1) then sums over the first
-- dimension of the looked-up rows.
local lookupSeq = nn.Sequential()
lookupSeq:add(nn.LookupTable(5, 6))
lookupSeq:add(nn.Sum(1))

-- A concatenation layer, output size 6.
local concat = nn.Concat(1)
if opt.noConcat then
  -- Bypass nn.Concat entirely: the lookup pipeline becomes the first layer.
  concat = lookupSeq
else
  -- Normally we'd add more than one module to the concat layer.
  -- This is just to demonstrate the error.
  concat:add(lookupSeq)
end

-- A simple regression model.
local model = nn.Sequential()
model:add(concat)
model:add(nn.Linear(6, 1))

-- Do a forward and backward pass.
local output = model:forward(input)
local criterion = nn.MSECriterion()
-- The loss value is computed only to complete the forward pass; it is not
-- used afterwards.
local err = criterion:forward(output, target)
-- Gradient of the loss with respect to the model output.
local df_do = criterion:backward(output, target)

--[[
The backward pass on the model results in this error.
luajit: nn/Concat.lua:73: inconsistent tensor size at torch/pkg/torch/lib/TH/generic/THTensorCopy.c:7
stack traceback:
[C]: in function 'copy'
nn/Concat.lua:73: in function 'backward'
nn/Sequential.lua:73: in function 'backward'
./repro.lua:31: in main chunk
[C]: in function 'dofile'
th:131: in main chunk
[C]: at 0x0100e112d0
--]]
-- NOTE(review): the traceback points at a tensor copy inside nn.Concat's
-- backward; presumably the gradInput returned by lookupSeq does not have the
-- same size as `input` -- confirm against nn/Concat.lua:73.
-- This call is intentionally left unguarded: triggering the error is the
-- purpose of the script.
model:backward(input, df_do)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment