如何在torch中向图形模块添加附加层

如何在torch的nngraph包中向图形模块(gModule)添加新节点?我尝试使用add函数,这将节点添加到gModules对象中的module插槽中。然而,输出仍然来自于先前的最后一个节点。

简化代码:

require" nn"
require" nngraph"

--构建gModule的功能
function buildModule(input_size,hidden_size)
    local x = nn.Identity()()
    local out = x-nn.Linear(input_size,hidden_size)-nn.Tanh()
    return nn.gModule({x},{out})
end

network = buildModule(5,3)
--要添加的附加层
l2 = nn.Linear(3,10)
network:add(l2)

--预期尺寸为10的张量,但获得了一个尺寸为3的张量
print(network:forward(torch.randn(5)))
点赞
用户2658050
用户2658050

gModule 实际上不应该被改变。它支持 :add 的事实实际上是它作为 nn.Container 的子类的副作用,而不是设计决策。通常情况下,一旦创建了 gModule ,就不应该修改它的内部结构,因为你需要修改一些内部属性才能使所有工作正常。相反,如果你想要“在顶部”添加一些东西,只需定义一个新的容器,将前一个容器作为输入。

-- 构建一个 gModule 的函数
function buildModule(input_size,hidden_size)
    local x = nn.Identity()() -- 输入层
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh() -- 隐藏层和输出层
    return nn.gModule({x},{out}) -- 返回 gModule
end

network = buildModule(5,3) -- 构建网络

new_network = nn.Sequential() -- 定义新的容器
new_network:add(network) -- 添加网络
new_network:add(nn.Linear(3,10)) -- 在顶部添加一些东西
2016-11-13 16:31:34