Torch 'Gather' 问题

我有两个张量如下:

标准化张量:

1

10

94

[torch.LongStorage of size 3]

以及

批处理:

1

10

[torch.LongStorage of size 2]

我希望使用 'Batch' 在 '标准化张量' 的第三维中选择索引。 到目前为止,我已经使用 gather 如下所示:

normalised:long():gather(1, batch:long())

不幸的是,它返回以下错误。 “bad argument #1 to 'gather' (Input tensor must have same dimensions as output"

任何帮助将不胜感激!谢谢

点赞
用户4687565
用户4687565

这个答案假设以下情况:你有一个大小为x、y、z的三维张量,并且你想要一个大小为x、y、10的三维张量,其中x、y切片基于另一个大小为1、10的张量中列出的索引进行选择。

我个人花了很多时间思考“gather”方法的可能用途。唯一的结论是:它不适用于上述描述的问题。

可以使用“index”函数解决所述的问题:

local slice = normalised:gather(3, batch[1]:long())
2017-08-16 10:56:24