如何在Torch中将张量拆分为连续张量的表?

我正在尝试拆分张量,然后使用nn.ParallelTable在拆分的张量上做一些处理。但是,nn.SplitTable会使拆分张量不连续,这对我来说是不可取的。

这是我的网络体系结构:

nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.SplitTable
  (2): nn.ParallelTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.View(20, 1, 6, 6)
      |      (2): nn.SpatialConvolution(1 -> 10, 2x2)
      |      (3): nn.ReLU
      |    }
      |`-> (2): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.View(20, 1, 6, 6)
      |      (2): nn.SpatialConvolution(1 -> 10, 2x2)
      |      (3): nn.ReLU
      |    }
       ... -> output
  }

下面是产生它的代码:

parallelLayer = nn.ParallelTable()
parallelLayer:add(nn.Sequential():add(nn.View(20, 1, 6, 6)):add(nn.SpatialConvolution(1, 10, 2, 2)):add(nn.ReLU()))
parallelLayer:add(nn.Sequential():add(nn.View(20, 1, 6, 6)):add(nn.SpatialConvolution(1, 10, 2, 2)):add(nn.ReLU()))

net = nn.Sequential():add(nn.SplitTable(2))
net:add(parallelLayer)

完成前向传递后: net:forward(torch.rand(20, 2, 6, 6))

我会得到以下错误:

/home/amir/torch/install/share/lua/5.1/torch/Tensor.lua:457: 预期得到连续的张量
stack traceback:
    [C]: in function 'assert'
    /home/amir/torch/install/share/lua/5.1/torch/Tensor.lua:457: in function 'view'
    /home/amir/torch/install/share/lua/5.1/nn/View.lua:83: in function </home/amir/torch/install/share/lua/5.1/nn/View.lua:77>
    [C]: in function 'xpcall'
    /home/amir/torch/install/share/lua/5.1/nn/Container.lua:65: in function 'rethrowErrors'
    /home/amir/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function </home/amir/torch/install/share/lua/5.1/nn/Sequential.lua:41>
    [C]: in function 'xpcall'
    /home/amir/torch/install/share/lua/5.1/nn/Container.lua:65: in function 'rethrowErrors'
    /home/amir/torch/install/share/lua/5.1/nn/ParallelTable.lua:12: in function </home/amir/torch/install/share/lua/5.1/nn/ParallelTable.lua:10>
    [C]: in function 'xpcall'
    /home/amir/torch/install/share/lua/5.1/nn/Container.lua:65: in function 'rethrowErrors'
    /home/amir/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
    [string "_RESULT={net:forward(torch.rand(20, 2, 6, 6))}"]:1: in main chunk
    [C]: in function 'xpcall'
    /home/amir/torch/install/share/lua/5.1/trepl/init.lua:651: in function 'repl'
    ...amir/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:199: in main chunk
    [C]: at 0x00406670

但如果我执行:

net:get(2):forward({torch.rand(20, 6, 6), torch.rand(20, 6, 6)})

我不会得到任何错误。有没有人有任何关于如何将输入张量拆分为连续张量的想法?或者也许我应该问如何将张量连续拆分,然后使用ParallelTable?有没有更好的方法可以做到这一点我不知道的吗?

点赞