Torch - Lua / 获取矩阵中的最大索引

我正在尝试编写一个用于多米诺游戏的神经网络

输入是一个 8 x 8 x 3 的矩阵。我按以下方式组织矩阵:

第一个深度是游戏的状态,第二个深度是翻转的棋盘,最后一个深度是玩家的平面

输出是 8 x 8 的最佳游戏(由 Monte Carlo Tree Search 生成)

然后网络是一个 8 x 8 的张量,其概率为成为最佳游戏,我需要获取张量中最大概率的索引 (x,y)

我尝试使用函数 torch.max(tensor, 2) 和 torch.max(tensor?1),但我没有得到我需要的结果。

有没有人有任何线索可以帮助我?

非常感谢!

#out =  神经网络的输出,output 是目标输出 [indice][1] needs to check if the target is the same as prediction
max, bestTarget = torch.max(output[index][1],2)
maxP, bestPrediction = torch.max(out,2)
max, indT = torch.max(max,1)
maxP, indP = torch.max(maxP,1)
点赞
用户1235026
用户1235026

为了得到 out 的最大元素(best_row, best_col):

-- 首先计算每一行的最大元素以及对应的下标
maxP_per_row, bestColumn_per_row = torch.max(out,2)
-- 然后得到 best\_row 对应的最大元素和下标
best_p, best_row = torch.max(maxP_per_row, 1)
-- 然后找到 best\_row 所在的最优列
best_col = bestColumn_per_row[best_row]

你可以用同样的方法得到 target 的最大元素和下标。希望这有帮助。

2017-03-22 22:03:36