在 Lua Torch 中,两个零矩阵的乘积会产生 NaN 值。

我在 Lua/Torch 中遇到了 torch.mm 函数的奇怪行为。下面是一个演示问题的简单程序。

程序由一个循环组成,在该循环中,程序将乘以两个零的 2x2 矩阵,并测试产品矩阵的条目 ent 是否等于 nan。看起来程序应该永远运行,因为乘积应该始终等于 0,因此 ent 应该为 0。然而,该程序打印了:

error at iteration 548
0.000000 0.000000
nan nan
[torch.DoubleTensor of size 2x2]

这是为什么呢?

更新:

  1. 如果我用 torch.mm(prod,a,b) 替换 _prod = torch.mm(a,b)_,则问题消失,这表明内存分配出了问题。

  2. 我的 Torch 版本未编译 BLAS 和 LAPACK 库。在我重新编译了使用 OpenBLAS 的 Torch 后,问题消失了。但是,我仍然对其原因感到有兴趣。

点赞
用户1688185
用户1688185

代码部分自动生成了Lua封装器,用于 torch.mm,其可在 此处 找到。

当您在循环中编写 prod = torch.mm(a,b) 时,对应于幕后生成的以下 C 代码(由该封装器生成,感谢 cwrap):

/* this is the tensor that will hold the results */
arg1 = THDoubleTensor_new();
THDoubleTensor_resize2d(arg1, arg5->size[0], arg6->size[1]);
arg3 = arg1;
/* .... */
luaT_pushudata(L, arg1, "torch.DoubleTensor");
/* effective matrix multiplication operation that will fill arg1 */
THDoubleTensor_addmm(arg1,arg2,arg3,arg4,arg5,arg6);

因此:

  • 创建一个新的结果张量,并将其大小调整为正确的尺寸,
  • 但是,此新张量未初始化,即没有 calloc 或显式填充,因此它指向垃圾内存,可能包含NaN值,
  • 将此张量推送到堆栈上,以便在Lua一侧作为返回值可用。

最后一点意味着这个返回的张量与初始的 prod 张量不同(即,在循环中,prod 遮盖了初始值)。

另一方面,调用 torch.mm(prod,a,b) _确实_使用了您的初始 prod 张量来存储结果(在幕后的情况下,在这种情况下没有必要创建专用张量)。由于在您的代码片段中未初始化/填充它,因此它也可能包含垃圾。

在这两种情况下,核心操作都是像 C = beta * C + alpha * A * B 这样的 gemm 乘法,其中beta = 0,alpha = 1。naive implementation看起来像这样:

  real *a_ = a;
  for(i = 0; i < m; i++)
  {
    real *b_ = b;
    for(j = 0; j < n; j++)
    {
      real sum = 0;
      for(l = 0; l < k; l++)
        sum += a_[l*lda]*b_[l];
      b_ += ldb;
      /*
       * WARNING: beta*c[j*ldc+i] could give NaN even if beta=0
       *          if the other operand c[j*ldc+i] is NaN!
       */
      c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
    }
    a_++;
  }

注释是我的。

所以:

  1. 使用 torch.mm(a,b):在每次迭代时,创建一个新的结果张量,而未初始化它(它可能包含 NaN 值)。因此,每次迭代都存在返回 NaN 的风险(见上方警告),
  2. 使用 torch.mm(prod,a,b):由于您未初始化 prod 张量,因此存在相同的风险。但是:此风险仅在重复/直到循环的第一次迭代时存在,因为在 prod 被填充为0后,后续迭代将重新使用它。

因此,这就是为什么您在此处没有观察到问题的原因(它的发生频率更低)。

在情况1中:在 Torch 层面应改进此问题,即确保封装器初始化输出(例如,使用 THDoubleTensor_fill (arg1,0))。

在情况2中:您应该最初初始化 prod,并使用 torch.mm(prod,a,b) 结构以避免任何 NaN 问题。


编辑:此问题现已解决(请参见此 pull request)。

2015-10-10 21:40:22