Lua Torch 中类似于 np.where() 的等效函数?

我有一个 ByteTensor,想要找到其中元素为 1 的索引。在 numpy 中,我可以这样做

a = np.array([1,0,1,0,1])
return np.where(a)

这将返回 (array([0, 2, 4]),)。在 Torch 中是否有此功能?

(在我的特定情况下,我想要使用这些索引来索引多个不同的 Tensor 对象,但知道如何在一般情况下做这件事也很好。)

点赞
用户1688185
用户1688185

你可以使用 torch.nonzero 函数,例如:

> a = torch.ByteTensor{1,0,1,0,1}
> print(torch.nonzero(a))
 1
 3
 5
[torch.LongTensor of size 3x1]

如果你只需要寻找值为 1 的元素,可以使用逻辑运算符:

> a = torch.ByteTensor{1,2,1,6,1}
> a:eq(1):nonzero()
2016-03-09 08:33:11