Data/Machine learning
[PyTorch] .detach()
DS-Lee
2021. 1. 28. 16:42
Tensor가 기록을 추적하는 것을 중단하게 하려면, .detach()
를 호출하여 연산 기록으로부터 분리(detach)하여 이후 연산들이 추적되는 것을 방지할 수 있습니다. (출처)
Example (Source: here)
modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)
x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad)
print(modelB.weight.grad)
c = modelC(a)
c.mean().backward()
print(modelA.weight.grad)
print(modelC.weight.grad)