在 Torch 中拆分层后序列化神经网络大小
2016-10-26 12:35:55
收藏:0
阅读:74
评论:1
训练神经网络(例如多层感知机)后,在预测时我想将第一层从其他层中分离出来。为了做到这一点,我找到的唯一方法,以便拥有正确大小的文件,如下所示:
循环遍历所有层,并将它们添加到两个容器之一(第一层或其他所有层)中,然后使用torch.save函数分别保存。有趣的是,在将它们添加到两个容器之一之前,我需要检索每个层的参数,否则,当保存时,两个文件(第一层和其他所有层)具有相同的文件大小。
下面的代码片段比我以前的所有解释都更有帮助:
local function split_model(network)
-- 由于某种原因,所有模型在保存时都具有相同的大小
-- 如果不拆分,先调用 'getParameters()'。
first_layer = nn.Sequential()
all_the_rest = nn.Sequential()
for i = 1,network:size() do
local l = network:get(i)
local l_params,_ = l:getParameters()
if i == 1 then
first_layer:add(l)
else
all_the_rest:add(l)
end
end
return first_layer,all_the_rest
end
local first_layer,all_the_rest = split_model(network)
torch.save("checkpoints/mlp.t7",.network)
torch.save("checkpoints/first_layer.t7",first_layer)
torch.save("checkpoints/all_the_rest.t7",all_the_rest)
点赞
评论区的留言会收到邮件通知哦~
推荐文章
- Lua 虚拟机加密load(string.dump(function)) 后执行失败问题如何解决
- 我想创建一个 Nginx 规则,禁止访问
- 如何将两个不同的lua文件合成一个 东西有点长 大佬请耐心看完 我是小白研究几天了都没搞定
- 如何在roblox studio中1:1导入真实世界的地形?
- 求解,lua_resume的第二次调用继续执行协程问题。
- 【上海普陀区】内向猫网络招募【Skynet游戏框架Lua后端程序员】
- SF爱好求教:如何用lua实现游戏内调用数据库函数实现账号密码注册?
- Lua实现网站后台开发
- LUA错误显式返回,社区常见的规约是怎么样的
- lua5.3下载库失败
- 请问如何实现文本框内容和某个网页搜索框内容连接,并把网页输出来的结果反馈到另外一个文本框上
- lua lanes多线程使用
- 一个kv数据库
- openresty 有没有比较轻量的 docker 镜像
- 想问一下,有大佬用过luacurl吗
- 在Lua执行过程中使用Load函数出现问题
- 为什么 neovim 里没有显示一些特殊字符?
- Lua比较两个表的值(不考虑键的顺序)
- 有个lua简单的项目,外包,有意者加微信 liuheng600456详谈,最好在成都
- 如何在 Visual Studio 2022 中运行 Lua 代码?

以下是 Alban Desmaison 在 Google groups 中针对同一个问题的回答:
require 'nn' local subset1 = nn.Linear(2,2) local subset2 = nn.Linear(2,2) local network = nn.Sequential():add(subset1):add(subset2) print("Before getParameters:", subset1.weight:storage():size()) -- 4 elements network_params,_ = network:getParameters() print("After getParameters:", subset1.weight:storage():size()) -- 12 elements subset1.weight:random() -- 更改权重以查看是否仍然工作 print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true -- 保持 network_params 有效 local clone_subset1 = subset1:clone() print("Cloned subset1 before getParameters:", clone_subset1.weight:storage():size()) -- 12 elements clone_subset1:getParameters() print("Cloned subset1 after getParameters:", clone_subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias) subset1.weight:random() -- 更改权重以查看是否仍然工作 print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- true -- 不保持 network_params 有效(应该更快) local clone_subset1 = subset1:clone() print("subset1 before getParameters:", subset1.weight:storage():size()) -- 12 elements subset1:getParameters() print("subset1 after getParameters:", subset1.weight:storage():size()) -- 6 elements (4 weights + 2 bias) subset1.weight:random() -- 更改权重以查看是否仍然工作 print("network_params is valid?", network_params[1] == subset1.weight[1][1]) -- false