在torch中覆盖updateGradInput方法对我的自定义模块无效。

我为自己编写了一个名为 sharedModule 的模块。我已经重写了以下的 updateGradInput 方法:

function SharedModule:updateGradInput(input, gradOutput)
    test_grad = {}

    print("调用 updateGradInput")
    test_input = input
    test_gradOutput = gradOutput
    assert(type(gradOutput) == 'table' and #input == #gradOutput)

    local T = #input
    for t = 1, T do
            self.gradInput[t] = self.clones[t]: updateGradInput(input[t], gradOutput[t])
            test_grad[t] = self.gradInput[t]
    end
    print(#self.gradInput)  -- 打印正常值
    --self.gradInput = test_grad  --
    return self.gradInput  -- 空的,???
end

然而,当我在我的模块上调用 backward 方法时,self.gradInput 字段没有更新,有人能帮我解决这个问题吗?

点赞
用户5519228
用户5519228

backward 方法将调用两个名为 updateGradInputaccGradParams 的方法,结果发现错误出现在 accGradParams 中,因为有打字错误。

2015-11-03 14:41:50