当克隆的网络被放置在较大的模块中时,复制品不会被训练。

我想通过将一个小的网络模块克隆并放置在几个不同的大型网络中来训练它。但是当我将它放入这样的一个环境中时,它的参数在训练过程中没有更新。下面是一个小例子,展示了这个问题。

我这样制造了这个小原始网络:

function mkPrimitive()
    local inp  = nn.Linear(2, 2)()
    local outp = nn.Tanh()(nn.Linear(2, 2)(nn.Tanh()(inp)))
    return nn.gModule({inp}, {outp})
end
prim = mkPrimitive()

然后我将它放到一个较大的网络中,称为 toTrain,就像这样:

function mkNet()
    local fst = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')()
    local snd = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')(fst)
    return nn.gModule({fst}, {snd})
end

toTrain = mkNet()

然后我训练了这个较大的网络,并打印出它的参数和 prim 的参数。我看到,在迭代过程中,较大的 toTrain 网络的参数改变了,而 prim 的没有。下面是训练代码。有没有方法解决这个问题?

numRuns = 10
function train()
    local crit = nn.MSECriterion()
    for i = 1, numRuns do
        toTrain:zeroGradParameters()

        local inData = torch.rand(1, 2)  --生成一些输入/输出数据
        local outData = torch.rand(1, 2)

        local pred = toTrain:forward(inData)
        local err = crit:forward(pred, outData)
        local grad = crit:backward(pred, outData)
        toTrain:backward(inData, grad)
        toTrain:updateParameters(0.01)

        local bigWs = toTrain:getParameters()
        local primWs = prim:getParameters()

        print(bigWs) --大网络的参数在学习过程中改变,
        print(primWs) --但原始网络的参数没有。
        print("------------------------------")
    end
end
train()
点赞
用户2104596
用户2104596

getParameters()会更改网络参数的内存位置,因此任何共享都会丢失。请查看此链接获取更多详情。

2016-02-01 16:16:26