Data/Machine learning

[PyTorch] .detach()

DS-Lee 2021. 1. 28. 16:42

Tensor가 기록을 추적하는 것을 중단하게 하려면, .detach()를 호출하여 연산 기록으로부터 분리(detach)하여 이후 연산들이 추적되는 것을 방지할 수 있습니다. (출처)

Example (Source: here)

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad)
print(modelB.weight.grad)

Note that the derivatives w.r.t weights in modelA are not computed.

c = modelC(a)
c.mean().backward()
print(modelA.weight.grad)
print(modelC.weight.grad)

Note that the derivatives w.r.t weights in modelA are computed.