在Torch中添加自定义损失函数

在Torch中实现自定义损失函数的必要步骤是什么?

看起来你需要编写updateOutput和updateGradInput的实现。

就这样吗?那么你基本上创建了一个新类:

local CustomCriterion, parent =   torch.class('CustomCriterion','nn.Criterion')

并实现了以下两个函数:

function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)

这样正确吗,或者还有更多需要做的?

此外,对于这些提供的准则,这些函数是用C实现的,但我认为Lua实现也可以,尽管可能会慢一点?

点赞
用户1866656
用户1866656

我已经按照伪代码实现了形式为

--假设输入被划分为input_a、input_b
--         目标被相应地划分为target_a、target_b
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b)

的函数很多次,就像你描述的那样。所以据我所知,我认为两个问题的答案都是肯定的。

基本上,nn/MSECriterion.luathis 似乎支持这种做法。

2017-06-26 22:49:02