如何在torch中使用预定义顺序切片张量?

我有一个长度为10的数据集train = torch.range(1,10)。我想按照由p = torch.randperm(10)定义的随机顺序对其进行切片。

要按范围获取切片,可以执行a = train[{{1,3}}]以获取前三个元素。但是,假设我想获取第2、第3和第9个元素。我可以不使用像这样的for循环操作

for i = 1,3 do
  print(a[{ p[i] }])
end

其中

p[1] = 2, p[2] = 3, p[3] = 9.

'a = train[ p[{{1,3}}] ]'无法使用。

点赞
用户4687565
用户4687565

首先,有一个 index,但需要长张量:

train = torch.range(1,10)
p = torch.randperm(10):long()
print(train:index(p))
2017-04-16 17:32:44