Torch:如何更新模型参数?
2017-8-22 21:16:44
收藏:0
阅读:92
评论:1
这里是一个玩具模型。我在调用 backward 前仅打印了一次模型参数,然后再次打印模型参数。结果发现参数没有改变。如果在调用 backward 后添加了 model:updateParameters(<learning_rate>) 这行代码,就能看到参数被更新了。
但是在例子代码中,比如 https://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.lua,没有人真正调用过 updateParameters。另外,似乎 optim.sgd、optim.adam 和 nn.StochasticGradient 也从来没有调用过 updateParameters。我错过了什么?参数是如何自动更新的?如果我必须调用 updateParameters,为什么没有例子这样做?
require 'nn'
require 'optim'
local model = nn.Sequential()
model:add(nn.Linear(4, 1, false))
local params, grads = model:getParameters()
local criterion = nn.MSECriterion()
local inputs = torch.randn(1, 4)
local labels = torch.Tensor{1}
print(params)
model:zeroGradParameters()
local output = model:forward(inputs)
local loss = criterion:forward(output, labels)
local dfdw = criterion:backward(output, labels)
model:backward(inputs, dfdw)
-- 将以下行注释去掉,参数将被更新:
-- model:updateParameters(1000)
print(params)
点赞
评论区的留言会收到邮件通知哦~
推荐文章
- Lua 虚拟机加密load(string.dump(function)) 后执行失败问题如何解决
- 我想创建一个 Nginx 规则,禁止访问
- 如何将两个不同的lua文件合成一个 东西有点长 大佬请耐心看完 我是小白研究几天了都没搞定
- 如何在roblox studio中1:1导入真实世界的地形?
- 求解,lua_resume的第二次调用继续执行协程问题。
- 【上海普陀区】内向猫网络招募【Skynet游戏框架Lua后端程序员】
- SF爱好求教:如何用lua实现游戏内调用数据库函数实现账号密码注册?
- Lua实现网站后台开发
- LUA错误显式返回,社区常见的规约是怎么样的
- lua5.3下载库失败
- 请问如何实现文本框内容和某个网页搜索框内容连接,并把网页输出来的结果反馈到另外一个文本框上
- lua lanes多线程使用
- 一个kv数据库
- openresty 有没有比较轻量的 docker 镜像
- 想问一下,有大佬用过luacurl吗
- 在Lua执行过程中使用Load函数出现问题
- 为什么 neovim 里没有显示一些特殊字符?
- Lua比较两个表的值(不考虑键的顺序)
- 有个lua简单的项目,外包,有意者加微信 liuheng600456详谈,最好在成都
- 如何在 Visual Studio 2022 中运行 Lua 代码?

backward()函数不应该改变参数,它只是计算网络所有参数对于误差函数的导数。通常,训练的步骤如下:
repeat local output = model:forward(input) --查看模型的预测结果 local loss = criterion:forward(output, answer) --查看错误率 local loss_grad = criterion:backward(output, answer) --查看最错误的位置 model:backward(input,loss_grad) --查看每个参数对误差的贡献程度 model:updateParameters(learningRate) --根据错误情况修正参数 model:zeroGradParameters() --由于网络参数已经变化,老的梯度无用了 until is_user_satisfied()updateParameters实现了最简单的优化算法(梯度下降)。如果想要,可以自己写函数代替它。在理论上,可以显式遍历整个网络来更新参数。在实际操作中,通常调用getParameters()。local model_parameters,model_parameters_gradient=model:getParameters()这会返回所有参数的均匀张量及其梯度。这些张量是网络中的视图,因此对它们进行更改会影响网络。不一定知道网络中哪个点对应的是哪个值,但大多数优化器并不关心这一点。
optim.sgd用法的演示可以在demo中找到:具体内容在演示中有介绍,但在此处,重要的是优化器将
model_parameters作为参数接收,并具有对网络进行写操作的功能。尽管在文档中并没有明确说明,但在 source code 中可以看到,优化器会改变其输入张量的值(同时,注意它返回的是接收到的 相同的 张量)。