torch / lua: 从张量中检索 n-best 子集

现在有以下代码,它存储了 pred 中每个问题最高分数的索引,并将其转换为字符串。

我想为每个问题的 n-best 索引执行相同的操作,而不仅仅是单个具有最高分数的索引,并将它们转换为字符串。我还想显示每个索引(或每个转换后的字符串)的分数。

所以必须对 scores 进行排序,将 pred 更改为多行/列,而不是 1 x nqs。并且必须能够检索出 pred 中每个条目的相应 score 值。

我对 lua/torch 语法毫无头绪,非常感谢任何帮助。

nqs=dataset['question']:size(1);
scores=torch.Tensor(nqs,noutput);
qids=torch.LongTensor(nqs);
for i=1,nqs,batch_size do
    xlua.progress(i, nqs)
    r=math.min(i+batch_size-1,nqs);
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r);
end

tmp,pred=torch.max(scores,2);

answer=json_file['ix_to_ans'][tostring(pred[{i,1}])]
print(answer)
点赞
用户1866656
用户1866656

以下是我的尝试,在一个简单的随机scores张量中演示它的行为:

> scores=torch.floor(torch.rand(4,10)*100)
> =scores
 9   1  90  12  62   1  62  86  46  27
 7   4   7   4  71  99  33  48  98  63
 82   5  73  84  61  92  81  99  65   9
 33  93  64  77  36  68  89  44  19  25
[torch.DoubleTensor of size 4x10]

现在,因为您想要每个问题(行)的“N”个最佳索引,让我们对张量的每一行进行排序:

> values,indexes=scores:sort(2)

现在,让我们看看返回的张量包含什么内容:

> =values
  1   1   9  12  27  46  62  62  86  90
  4   4   7   7  33  48  63  71  98  99
  5   9  61  65  73  81  82  84  92  99
  19  25  33  36  44  64  68  77  89  93
  [torch.DoubleTensor of size 4x10]

> =indexes
  2   6   1   4  10   9   5   7   8   3
  2   4   1   3   7   8  10   5   9   6
  2  10   5   9   3   7   1   4   6   8
  9  10   1   5   8   3   6   4   7   2
  [torch.LongTensor of size 4x10]

正如您所看到的,values的第i行是scores的第i行的排序版本(按递增顺序排列),indexes中的每一行都会给您相应的索引。

你可以得到每个问题(即行)的N个最佳值/索引

> N_best_indexes=indexes[{{},{indexes:size(2)-N+1,indexes:size(2)}}]
> N_best_values=values[{{},{values:size(2)-N+1,values:size(2)}}]

让我们看看对于给定的示例,N=3时它们的值:

> return N_best_indexes
 7  8  3
 5  9  6
 4  6  8
 4  7  2
[torch.LongTensor of size 4x3]

> return N_best_values
 62  86  90
 71  98  99
 84  92  99
 77  89  93
[torch.DoubleTensor of size 4x3]

因此,第j个问题的第k个最佳值为N_best_values[{{j},{values:size(2)-k+1}}],并且它在scores矩阵中的相应索引由此行,列值给出:

行= j
列= N_best_indexes[{{j},indexes:size(2)-k+1}}].

例如,第二个问题的第一个最佳值(k=1)为99,它在scores的第2行和第6列。您可以看到,values[{{2},values:size(2)}}]99,而indexes[{{2},{indexes:size(2)}}]会给出相应的列索引6。

希望我解释了我的解决方案。

2017-07-30 11:31:17