在 Torch 中拆分层后序列化神经网络大小

训练神经网络(例如多层感知机)后,在预测时我想将第一层从其他层中分离出来。为了做到这一点,我找到的唯一方法,以便拥有正确大小的文件,如下所示:

循环遍历所有层,并将它们添加到两个容器之一(第一层其他所有层)中,然后使用torch.save函数分别保存。有趣的是,在将它们添加到两个容器之一之前,我需要检索每个层的参数,否则,当保存时,两个文件(第一层其他所有层)具有相同的文件大小。

下面的代码片段比我以前的所有解释都更有帮助:

local function split_model(network)
    -- 由于某种原因,所有模型在保存时都具有相同的大小
    -- 如果不拆分,先调用 'getParameters()'。
    first_layer = nn.Sequential()
    all_the_rest = nn.Sequential()
    for i = 1,network:size() do
        local l = network:get(i)
        local l_params,_ = l:getParameters()
        if i == 1 then
            first_layer:add(l)
        else
            all_the_rest:add(l)
        end
    end
    return first_layer,all_the_rest
end

local first_layer,all_the_rest = split_model(network)
torch.save("checkpoints/mlp.t7",.network)
torch.save("checkpoints/first_layer.t7",first_layer)
torch.save("checkpoints/all_the_rest.t7",all_the_rest)
点赞
用户1522304
用户1522304

以下是 Alban Desmaison 在 Google groups 中针对同一个问题的回答:

你好,

这种行为的原因是由于 getParameters 工作方式引起的。为了能够返回包含所有参数的平坦的张量,它实际上会创建一个包含所有权重的单一存储,并且每个模块的权重都是此存储的一部分。当您保存网络中任何元素的权重时,它将必须保存权重张量,并保存底层存储。因此,如果您在整个网络上调用了 getParameters,如果保存了任何模块,您将保存所有网络的权重。在这里,当您在单个模块上调用 getParameters 时,它实际上会重新创建此单个存储,但仅针对此单个模块,因此当您保存它时,它仅包含您想要的权重。但是请注意,您在完整网络上执行的 getParameters 返回的平铺参数不再有效!!!

在这里有两个解决方案: - 如果您不想使用从整个网络获得的参数,则可以在保存每个子集之前仅调用每个子集的 getParameters。这将打破潜在的存储,只包含此网络子集,您将仅保存所需内容(共享存储只存一次)。 - 如果您希望能够继续使用原始 getParameters 中的参数,则可以执行与上述相同的操作,但使用它们的克隆版本进行 getParameters 和保存。

因为代码片段更易于理解:

require 'nn'

local subset1 = nn.Linear(2,2)
local subset2 = nn.Linear(2,2)

local network = nn.Sequential():add(subset1):add(subset2)

print("Before getParameters:", subset1.weight:storage():size()) -- 4 elements
network_params,_ = network:getParameters()
print("After getParameters:", subset1.weight:storage():size()) -- 12 elements
subset1.weight:random() -- 更改权重以查看是否仍然工作
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true

-- 保持 network_params 有效
local clone_subset1 = subset1:clone()
print("Cloned subset1 before getParameters:", clone_subset1.weight:storage():size()) -- 12 elements
clone_subset1:getParameters()
print("Cloned subset1 after getParameters:", clone_subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias)
subset1.weight:random() -- 更改权重以查看是否仍然工作
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true

-- 不保持 network_params 有效(应该更快)
local clone_subset1 = subset1:clone()
print("subset1 before getParameters:", subset1.weight:storage():size()) -- 12 elements
subset1:getParameters()
print("subset1 after getParameters:", subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias)
subset1.weight:random() -- 更改权重以查看是否仍然工作
print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- false
2017-02-28 16:06:46