如何在torch7中进行多任务学习?

这里可以做简单的多任务网络。 但我想要像这个一样的东西 输入图像描述。 现在我按照以下方式构建模型:

model = nn.Sequential()
model:add(nn.Linear(3,5))
prl1 = nn.ConcatTable()
prl1:add(nn.Linear(5,1))
prl2 = nn.ConcatTable()
prl2:add(nn.Linear(5,1))
prl2:add(nn.Linear(5,1))
prl1:add(prl2)
model:add(prl1)

我的输出是:

input = torch.rand(5,3)
output = model:forward(input)
output
{
  1 : DoubleTensor - size: 5x1
  2 :
    {
      1 : DoubleTensor - size: 5x1
      2 : DoubleTensor - size: 5x1
    }
}

我应该如何构建我的标准?

点赞
用户5995434
用户5995434

我似乎通过两个步骤找到了解决方法:

1.在上述网络中使用nn.Concat而非nn.ConcatTable,这会使输出变成一个简单的NxM张量,例如在使用nn.Concat而非nn.ConcatTable的情况下,一个5x3张量将进入上述网络。

2.在获得NxM张量之后,我使用nn.ConcatTable、nn.Concat和nn.Select的组合来使输出成为包含每个结果张量的简单表格。

以下是第二步的一个简单示例:

model = nn.Sequential()
model:add(nn.Linear(3,5))

prl = nn.ConcatTable()

spl1 = nn.Concat(2)

seq1 = nn.Sequential()
seq1:add(nn.Select(2, 1))
seq1:add(nn.Reshape(1))

seq2 = nn.Sequential()
seq2:add(nn.Select(2, 2))
seq2:add(nn.Reshape(1))

seq3 = nn.Sequential()
seq3:add(nn.Select(2, 3))
seq3:add(nn.Reshape(1))

spl1:add(seq1)
spl1:add(seq2)
spl1:add(seq3)
prl:add(spl1)

spl2 = nn.Concat(2)

seq4 = nn.Sequential()
seq4:add(nn.Select(2, 4))
seq4:add(nn.Reshape(1))

seq5 = nn.Sequential()
seq5:add(nn.Select(2, 5))
seq5:add(nn.Reshape(1))

spl2:add(seq4)
spl2:add(seq5)
prl:add(spl2)

model:add(prl)

input = torch.rand(5,3)
output = model:forward(input)

输出将是:

th> output
{
  1 : DoubleTensor - size: 5x3
  2 : DoubleTensor - size: 5x2
}
2017-12-14 03:03:56