torch7: 过滤NaN值

给定任何一般的 float torch.Tensor,可能包含一些 NaN 值,我正在寻找一个有效的方法来用零替换其中的所有 NaN 值,或者将它们全部删除并在另一个新的 Tensor 中过滤出 "有用" 值。

我知道一种简单的方法是手动迭代给定张量中的所有值,并相应地将它们替换为零或在新张量中拒绝它们。

是否存在一些预定义的 Torch 函数或函数组合,可以更有效地从性能上实现这一点,并依赖 Torch 的固有 CPU-GPU 优化?

点赞
用户4850610
用户4850610

看起来 torch 中没有检查张量是否为 NaN 的函数。但是由于 NaN != NaN,可以这样解决:

a = torch.rand(4, 5)
a[2][3] = tonumber('nan')
nan_mask = a:ne(a)
notnan_mask = a:eq(a)

print(a)
 0.2434  0.1731  0.3440  0.3340  0.0519
 0.0932  0.4067  nan     0.1827  0.5945
 0.3020  0.1035  0.5415  0.3329  0.7881
 0.6108  0.9498  0.0406  0.9335  0.3582
[torch.DoubleTensor of size 4x5]

print(nan_mask)
 0  0  0  0  0
 0  0  1  0  0
 0  0  0  0  0
 0  0  0  0  0
[torch.ByteTensor of size 4x5]

有了这些掩码,可以高效地提取 NaN/非 NaN 值并用任何想要的值替换它们:

print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]

a[nan_mask] = 42
print(a)
  0.2434   0.1731   0.3440   0.3340   0.0519
  0.0932   0.4067  42.0000   0.1827   0.5945
  0.3020   0.1035   0.5415   0.3329   0.7881
  0.6108   0.9498   0.0406   0.9335   0.3582
[torch.DoubleTensor of size 4x5]
2016-05-12 21:18:12