Reference: RuntimeError: Trying to backward through the graph a second time... (Tencent Cloud+ Community)
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
torch.autograd.backward
torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
- retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
- create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
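Here is a small sketch of what create_graph does. It is not from the original post. It takes the second derivative of y = x³ at x = 3. It calls torch.autograd.grad, which accepts the same retain_graph and create_graph flags as backward(). The difference is that torch.autograd.grad returns the gradients instead of accumulating them into .grad. The function and values are only illustrative.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# create_graph=True records the derivative computation itself,
# so dy_dx gets a grad_fn and can be differentiated again.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)   # 3x^2 = 27
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)                # 6x = 18

print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0
```

Because retain_graph defaults to the value of create_graph, the first call also keeps the graph of y alive.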
retain_graph = True (when to use it?)
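We rarely need the retain_graph parameter in everyday code. It is required in a few special cases where two backward passes share part of the same computation graph:
- A network has two outputs, and each one calls backward: output1.backward(), output2.backward().
- A network has two losses, and each one calls backward: loss1.backward(), loss2.backward().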
Take case 2 as an example.
Writing the code like this raises the error quoted at the beginning of this post:
loss1.backward()
loss2.backward()
Correct code:
loss1.backward(retain_graph=True) # keep the intermediate buffers after this backward pass
loss2.backward() # all intermediate buffers are freed now, ready for the next iteration
optimizer.step() # 更新参数
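With retain_graph=True the intermediate buffers are kept, so the two backward() calls do not interfere with each other. The second call can still traverse the shared part of the graph, and the gradients of both losses are summed into .grad before optimizer.step().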
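Below is a minimal, self-contained sketch of case 2. It assumes a toy graph in which both losses share the intermediate tensor h = tanh(x). The names make_losses, x and h are hypothetical. The sketch does three things:
- It reproduces the error.
- It applies the retain_graph=True fix.
- It backpropagates loss1 + loss2 in a single call. This is one common alternative in the spirit of the retain_graph note above, and it is not part of the original post. Both routes should give the same gradient.

```python
import torch

def make_losses(x):
    # Both losses use h, so they share the tanh node of the graph.
    h = torch.tanh(x)
    return h.sum(), (h ** 2).sum()

torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)

# 1) Reproduce the error: the first backward() frees tanh's saved output.
loss1, loss2 = make_losses(x)
loss1.backward()
try:
    loss2.backward()
    failed = False
except RuntimeError:
    failed = True

# 2) Fix: retain the graph for the second backward() call.
x.grad = None
loss1, loss2 = make_losses(x)
loss1.backward(retain_graph=True)
loss2.backward()                  # the graph is freed after this call
grad_retain = x.grad.clone()      # gradients of both losses, summed

# 3) Alternative: one backward pass over the summed loss.
x.grad = None
loss1, loss2 = make_losses(x)
(loss1 + loss2).backward()
grad_sum = x.grad.clone()

# Analytic check: d/dx [tanh(x) + tanh(x)^2] = (1 - t^2)(1 + 2t), with t = tanh(x)
t = torch.tanh(x.detach())
expected = (1 - t ** 2) * (1 + 2 * t)

print(failed, torch.allclose(grad_retain, grad_sum), torch.allclose(grad_retain, expected))  # True True True
```

The summed version traverses the shared graph only once and never needs to retain it.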
Supplement: two networks, each with its own loss that needs its own backward pass: loss1.backward(), loss2.backward().
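# With two networks, define a separate optimizer for each network
# (pass keyword arguments: SGD's 4th positional parameter is dampening, not weight_decay)
optimizer1 = torch.optim.SGD(net1.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
.....
# Backpropagating the losses in the training loop
loss1 = loss()
loss2 = loss()
optimizer1.zero_grad()  # set net1's gradients to zero
loss1.backward(retain_graph=True)  # keep the intermediate buffers after this backward pass
optimizer1.step()
optimizer2.zero_grad()  # set net2's gradients to zero
loss2.backward()
optimizer2.step()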
scheduler = torch.optim.lr_scheduler.StepLR(
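Here is a runnable sketch of the two-network pattern. It rests on hypothetical assumptions:
- A shared learnable tensor z (not in the original) feeds two small nn.Linear networks.
- As a result, the two losses share only the tanh node of the graph.
- The hyperparameters are placeholders.

The shared part deliberately contains none of net1's parameters. If loss2 also backpropagated through net1, optimizer1.step() would run before loss2.backward() and modify weights that the retained graph still needs. Recent PyTorch versions reject that with an in-place modification error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: z is shared by both networks, so loss1 and loss2
# share the tanh node; net1 and net2 themselves are independent.
z = torch.randn(8, 4, requires_grad=True)
net1 = nn.Linear(4, 1)
net2 = nn.Linear(4, 1)

# Placeholder hyperparameters, passed by keyword.
optimizer1 = torch.optim.SGD(net1.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

w1_before = net1.weight.detach().clone()
w2_before = net2.weight.detach().clone()

h = torch.tanh(z)                  # shared part of the graph
loss1 = net1(h).pow(2).mean()      # hypothetical loss for net1
loss2 = net2(h).pow(2).mean()      # hypothetical loss for net2

optimizer1.zero_grad()
loss1.backward(retain_graph=True)  # keep tanh's saved output for loss2
optimizer1.step()

optimizer2.zero_grad()
loss2.backward()                   # without retain_graph above, this would raise
optimizer2.step()

print(torch.equal(w1_before, net1.weight.detach()), torch.equal(w2_before, net2.weight.detach()))
```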
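Appendix:
Explanation of the steps
- optimizer.zero_grad()
Sets the gradients to zero
(backward() adds to the existing .grad values rather than overwriting them, so gradients left over from the previous batch must be cleared first. Within a batch, the derivative of the loss with respect to a weight is the sum of the per-sample derivatives.)
Corresponds to d_weights = [0] * n
- outputs = net(inputs)
Forward pass: computes the predictions
- loss = Loss(outputs, labels)
Computes the loss
- loss.backward()
Backward pass: computes the gradients
Corresponds to d_weights = [d_weights[j] + (label[k] - output) * input[k][j] for j in range(n)]
- optimizer.step()
Updates all parameters
Corresponds to weights = [weights[k] + alpha * d_weights[k] for k in range(n)]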
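The five steps can be mimicked in plain Python. This sketch assumes a linear model with the squared-error loss 0.5 * (label - output)²; the post does not say which loss its d_weights formulas come from. Under that assumption, (label[k] - output) * input[k][j] is the negative gradient, which is why step() adds alpha * d_weights. The function train_step and the toy data are hypothetical.

```python
def train_step(weights, inputs, labels, alpha):
    """One batch of plain gradient descent on a linear model, step by step."""
    n = len(weights)
    d_weights = [0.0] * n                      # optimizer.zero_grad()
    loss = 0.0
    for k in range(len(inputs)):
        # outputs = net(inputs): forward pass of a linear model
        output = sum(weights[j] * inputs[k][j] for j in range(n))
        # loss = Loss(outputs, labels): squared error, summed over the batch
        loss += 0.5 * (labels[k] - output) ** 2
        # loss.backward(): accumulate this sample's negative gradient
        d_weights = [d_weights[j] + (labels[k] - output) * inputs[k][j] for j in range(n)]
    # optimizer.step(): move the weights along the accumulated direction
    new_weights = [weights[j] + alpha * d_weights[j] for j in range(n)]
    return new_weights, loss


# Hypothetical data generated by y = 2 * x0 + 3 * x1
inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
labels = [2.0, 3.0, 5.0, 7.0]
weights = [0.0, 0.0]
for _ in range(500):
    weights, loss = train_step(weights, inputs, labels, alpha=0.05)
print([round(w, 4) for w in weights])   # [2.0, 3.0]
```

Resetting d_weights at the top of every call plays the role of zero_grad(). Without the reset, gradients from earlier batches would keep piling up.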