在 Torch Lua 代码中找不到 addmm 函数的定义

我正在理解在 Torch Lua 中实现的神经网络。在通过线性层的反向传递期间,它调用了一个名为 Linear:updateGradInput 的函数( https://github.com/torch/nn/blob/master/Linear.lua#L75 )

function Linear:updateGradInput(input, gradOutput)
  if self.gradInput then

     local nElement = self.gradInput:nElement()
     self.gradInput:resizeAs(input)
     if self.gradInput:nElement() ~= nElement then
        self.gradInput:zero()
     end
     if input:dim() == 1 then
        self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
     elseif input:dim() == 2 then
        self.gradInput:addmm(0, 1, gradOutput, self.weight)
     end
     return self.gradInput
  end
end

在该函数中,通过调用一个名为 addmm 的函数执行了一个基本的矩阵乘法操作(https://github.com/torch/nn/blob/master/Linear.lua#L86 )。我找不到这个 addmm 函数的定义。

在 TH 库中定义了一个 addmm 函数( https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282),但是我不确定 Lua 代码与 C 中的这段代码有何关联。

点赞
用户3852745
用户3852745

刚刚发现了 Lua 代码和 C 代码之间的联系。在 Lua 代码中对 addmm 的调用引导到这个函数(https://github.com/torch/torch7/blob/master/TensorMath.lua#L487-L510),然后这个函数调用在 C Torch Library 中定义的 addmm 函数,它在这里定义(https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282)。

这很棘手,因为 Lua 通过字符串构建对 C 函数的调用。

2018-05-29 07:08:22