使用 nn.Linear 构建 Torch nngraph 节点。

我是 Torch/Lua 的新手,正在完成牛津机器学习课程的 practical5

我试图实现一个简单的层: m = x1 + x2 cmul linear(x3) 其中 cmul 是元素乘法,linear 只是一个线性层。

我的代码如下:

-- 线性层参数
params = {
    x3_size1 = 10,
    x3_size2 = 30
}
-- 用于作为图中的输入节点的虚拟节点
x1 = nn.Identity()()
x2 = nn.Identity()()
x3 = nn.Identity()()

-- 建立 x1 + x2 cmul linear(x3) 的计算图
l3 = nn.Linear(params.x3_size1, params.x3_size2)(x1)
m23 = nn.CMulTable()({x2,l3})
add = nn.CAddTable()({x1, m23})

-- 设置图的输入和输出
m = nn.gModule({x1,x2,x3}, {add})

graph.dot(mlp.fg, "mlp")

然而,我收到了以下错误消息:

  /Users/yiranzhang/torch/install/bin/luajit: /Users/yiranzhang/torch/install/share/lua/5.1/nn/Linear.lua:36: attempt to index local 'input' (a nil value)
stack traceback:
    /Users/yiranzhang/torch/install/share/lua/5.1/nn/Linear.lua:36: in function 'forward'
    /Users/yiranzhang/torch/install/share/lua/5.1/nn/Module.lua:232: in function </Users/yiranzhang/torch/install/share/lua/5.1/nn/Module.lua:231>
    [C]: at 0x0156d0d0
    practical5.lua:32: in main chunk
    [C]: in function 'dofile'
    ...hang/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:131: in main chunk
    [C]: at 0x01013242e0

如果我只想要一个 aa = nn.Linear(10,20)()

我得到了与上面相同的错误。

即使我遵循 Torch GitHub 上的 示例

我还是得到相同的错误。

更新解决

我忘记导入 nngraph 包了。虽然代码中两个 nn 都被称为 nn,但它们实际上是不同的包。

应该这样做

require 'nngraph'

而我只做了

require 'nn'
点赞
用户6792483
用户6792483

在这一行中,最后一个参数应该是 x3 而不是 x1:

l3 = nn.Linear(params.x3_size1, params.x3_size2)(x3)
2016-09-04 06:36:28