从一维张量中提取前k个值的索引。

在 Torch ( torch.Tensor ) 中给定一个包含可比较值 (例如浮点数) 的 1-D 张量,我们如何提取该张量中前 k 个值的索引?

除了暴力方法外,我正在寻找一些 Torch/lua 提供的 API 调用,可以高效地执行此任务。

点赞
用户350664
用户350664

只需遍历张量并运行比较即可:

require 'torch'

data = torch.Tensor({1,2,3,4,505,6,7,8,9,10,11,12})
idx  = 1
max  = data[1]

for i=1,data:size()[1] do
   if data[i]>max then
      max=data[i]
      idx=i
   end
end

print(idx,max)

--EDIT-- 回复您的编辑:使用torch.max操作,该操作在此处有文档记录:https://github.com/torch/torch7/blob/master/doc/maths.md#torchmaxresval-resind-x-dim ...

y,i = torch.max(x,1)返回x中每列(跨行)的最大元素以及相应的索引张量i
2016-01-12 17:53:35
用户1688185
用户1688185

截至 pull request #496,现在 Torch 包括了一个内置 API,名为 torch.topk。例如:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- 获取 3 个最小的元素
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- 除此之外,你还可以同时获取索引
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- 或者你也可以获取 k 个最大的元素,方法如下
-- (具体详情请看 API 文档)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

目前的 CPU 实现采用了一个排序和筛选的方法(未来计划进行改进)。 话虽如此,目前也正在 审核 一个专门针对 cutorch 优化的 GPU 实现。

2016-01-13 08:57:29
用户3650983
用户3650983

你可以使用topk函数。

例如:

import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

values,indices = t.topk(2)

print(values)
print(indices)

结果如下:

tensor([9.5000, 6.1000])
tensor([2, 4])
2019-04-22 23:38:32