如何在Torch nn包中禁用omp?

具体来说,我希望 nn.LogSoftMax 在输入张量的大小很小时不使用 omp。我有一个小脚本来测试运行时长。

如果 arg[1] 是 10,那么我基本的对数 Softmax 函数运行得更快:

0.00021696090698242
0.033425092697144

但是一旦 arg [1] 是 10,000,000,omp 确实会有很大帮助:

29.561321973801
0.11547803878784

所以我怀疑 omp 的开销非常高。如果我的代码需要多次调用 log softmax 小输入(比如张量大小只有 3),那么它将花费太多时间。是否有一种方法在某些情况下手动禁用 omp 的使用(但并不总是禁用)?

点赞
用户1688185
用户1688185

有没有办法在某些情况下手动禁用 OMP 使用(但不总是)?

如果你真的想这样做,一种可能的方式是使用 torch.setnumthreadstorch.getnumthreads,如下所示:

local nth = torch.getnumthreads()
torch.setnumthreads(1)
-- 做一些事情
torch.setnumthreads(nth)

那么你就可以像这样猴子补丁 nn.LogSoftMax:

nn.LogSoftMax.updateOutput = function(self, input)
  local nth = torch.getnumthreads()
  torch.setnumthreads(1)
  local out = input.nn.LogSoftMax_updateOutput(self, input)
  torch.setnumthreads(nth)
  return out
end
2015-05-21 08:12:31