admin管理员组

文章数量:1530845

参考RuntimeError: Trying to backward through the graph a second time... - 云+社区 - 腾讯云

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

torch.autograd.backward

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
  • retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way.Defaults to the value of create_graph.
  • create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products.Defaults to False.

retain_graph = True (when to use it?)

retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它:

  1. 一个网络有两个output分别执行backward进行回传的时候: output1.backward(), output2.backward().
  2. 一个网络有两个loss需要分别执行backward进行回传的时候: loss1.backward(), loss1.backward().

以情况2.为例
如果代码这样写,就会出现博文开头的参数:

loss1.backward()
loss2.backward()

正确代码:

loss1.backward(retain_graph=True) #保留backward后的中间参数。
loss2.backward() #所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

补充:两个网络的两个loss需要分别执行backward进行回传的时候: loss1.backward(), loss1.backward().

#两个网络的情况需要分别为两个网络分别定义optimizer
optimizer1= torch.optim.SGD(net1.parameters(), learning_rate, momentum,weight_decay)
optimizer2= torch.optim.SGD(net2.parameters(), learning_rate, momentum,weight_decay)
.....
#train 部分的loss回传处理
loss1 = loss()
loss2 = loss()

optimizer1.zero_grad() #set the grade to zero
loss1.backward(retain_graph=True) #保留backward后的中间参数。
optimizer1.step()

optimizer2.zero_grad() #set the grade to zero
loss2.backward()
optimizer2.step()

scheduler = torch.optim.lr_scheduler.StepLR(
附录:

步骤解释

  • optimizer.zero_grad()

将梯度初始化为零
(因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和)
对应d_weights = [0] * n

  • output = net(inputs)

前向传播求出预测的值

  • loss = Loss(outputs, labels)

求loss

  • loss.backward()

反向传播求梯度
对应d_weights = [d_weights[j] + (label[k] - output ) * input[k][j] for j in range(n)]

  • optimizer.step()

更新所有参数
对应weights = [weights[k] + alpha * d_weights[k] for k in range(n)]

本文标签: RuntimeErrorGraphTIME