nn.CDivTable在backward被调用时是否有有效的原因引起错误?

我最近开始使用Torch框架和Lua脚本语言进行神经网络的探索。我掌握了线性网络的基础知识,因此尝试了更加复杂但简单的东西:

想法是我有3个输入,我必须选择前两个,将它们除以并将结果转发给线性模块。因此,我制作了这个小脚本:

require "nn";
require "optim";

local N = 3;

local input = torch.Tensor{
    {1, 2, 3},
    {9, 20, 20},
    {9, 300, 1},
};

local output = torch.Tensor(N);
for i=1, N do
    output[i] = 1;
end

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};

local maxIteration = 100000;
for i=1, maxIteration do
    local function f(params)
        gradParams:zero();

        local outputs = ratioPerceptron:forward(input);
        local loss = criterion:forward(outputs, output);
        local dloss_doutputs = criterion:backward(outputs, output);
        ratioPerceptron:backward(input, dloss_doutputs);

        return loss, gradParams;
    end

    optim.sgd(f, params, optimState);
end

在训练期间调用backward时,这将失败并显示以下错误:

CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator

但如果我从sequential模块中删除CDivTable,并将nn.Reshape和nn.Linear更改为二维输入(因为我们删除了将二维输入除以产生一维输出的CDivTable),则会像这样完全没有错误地完成训练:

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());

是否有其他方法可以将两个选择的输入除以并将结果转发到线性模块?

点赞
用户3754413
用户3754413

模块 CDivTable 接受一个表格作为输入,并将第一个表格的元素除以第二个表格的元素。在这里,你将一个单独的输入馈送到你的网络中,而不是一个包含两个输入的表格。这就是为什么我认为你遇到了 null 的错误。Torch 无法理解你的输入(它由两个向量组成)应该被视为两个向量的表格。它只看到一个大小为 2x3 的张量!因此,你必须告诉 Torch从输入中创建一个表格。因此,你可以使用模块SplitTable(dim),它将沿着维度 dim 将输入拆分为表格。

在狭缩模块之后插入这一行 ratioPerceptron:add(nn.SplitTable(1))

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

此外,在遇到此类错误时,我建议你通过加入print语句来查看网络计算的内容:在添加导致错误的模块的代码行之前加入一行 print(ratioPerceptron:forward(input))

2017-04-28 14:21:05