这个函数:input.nn.MSECriterion_updateOutput(self, input, target) 在Lua/Torch中是如何工作的?

我有这个函数:

    function MSECriterion:updateOutput(input, target)
        return input.nn.MSECriterion_updateOutput(self, input, target)
    end

现在,

   input.nn.MSECriterion_updateOutput(self, input, target)

返回一个数字。我不知道它是如何做到的。我使用调试器逐步执行,似乎没有中间步骤,它就直接计算出了一个数字。

 input是大小为1的张量(比如说,-.234)。并且

 nn.MSECriterion_updateOutput(self, input, target)看起来就是函数MSECriterion:updateOutput(input, target)。

我对它如何计算一个数字感到困惑。

我对允许这样做感到困惑。参数input是一个张量,其并没有任何叫做nn.MSE input.nn.MSECriterion_updateOutput的方法。

点赞
用户1688185
用户1688185

当你执行 require "nn" 时,它会加载 init.lua,然后执行 require('libnn')。这是 torch/nn 的 C 扩展。

如果你查看 init.c,你可以找到 luaopen_libnn:当 libnn.sorequire 时,就会调用这个初始化函数。

这个函数负责初始化 torch/nn 的所有部分,包括通过 nn_FloatMSECriterion_init(L)nn_DoubleMSECriterion_init(L) 初始化 MSECriterion 的本地部分。

如果你查看 generic/MSECriterion.c,你可以找到通用(即宏扩展的 floatdouble初始化函数

static void nn_(MSECriterion_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(MSECriterion__), "nn");
  lua_pop(L,1);
}

这个初始化函数修改了任何 torch.FloatTensortorch.DoubleTensor 的元表,使得它填充了一堆函数在 nn 关键字下(更多详细信息请参见 Torch7 Lua C API)。这些函数在这之前被定义:

static const struct luaL_Reg nn_(MSECriterion__) [] = {
  {"MSECriterion_updateOutput", nn_(MSECriterion_updateOutput)},
  {"MSECriterion_updateGradInput", nn_(MSECriterion_updateGradInput)},
  {NULL, NULL}
};

换句话说,任何张量都有这些函数来自于它的元表:

luajit -lnn
> print(torch.Tensor().nn.MSECriterion_updateOutput)
function: 0x40921df8
> print(torch.Tensor().nn.MSECriterion_updateGradInput)
function: 0x40921e20

注意:这个机制对于所有具有 C 本地实现对应物的 torch/nn 模块相同。

因此,input.nn.MSECriterion_updateOutput(self, input, target) 的效果是调用 static int nn_(MSECriterion_updateOutput)(lua_State *L),就像你在 generic/MSECriterion.c 中看到的那样。

这个函数计算输入张量之间的均方误差。

2015-05-23 18:30:54