在一轮epoch中,updater做的事情
1、其封装好的算法会根据所得loss一次性更新所有层
按照反传的方法
2、自己写的优化算法接收batchsize做为输入参数
在本地自己写好对于网络怎么优化的写法,即lr、params这种在本地和优化算法和网络放一起,这样优化算法就知道要优化哪些些东西、
![](D:\论文\截图\image-20220629232552738.png
updater放只能放一个接收batchsize的函数,这个函数中的变量和调用的具体优化算法要在本地实例化
==loss的backward中会给所有参数梯度==,(loss函数根据传进去loss函数进行)
所以只要自己的updater优化算法把要更新的参数含进去
(然后弄个==实例化的函数对==
==放进去的参数进行梯度下降==)就行了。
(自己初始化时候跟网络在一个文件中然后加到自定义里面)
这样在训练中就可以使用自定义的优化算法了
这也就是其代码的逻辑,用传进来的loss算损失,backward算梯度存在各自的参数中(每一层的直接就 l.backward)
然后写一个接收batchsize的updater==(用于把梯度除以batchsize,因为在使用自定义函数时用的是loss的sum的backward)==
在这个updater(本身是本地写的,跟net的参数在一个文件)中,本地化给==所有==参数给到 ==自定义写的优化函数==
在优化函数中。对所有参数进行梯度下降更新,这样就完成了梯度反传