-
[PyTorch] .detach()Data/Machine learning 2021. 1. 28. 16:42
Tensor가 기록을 추적하는 것을 중단하게 하려면,
.detach()
를 호출하여 연산 기록으로부터 분리(detach)하여 이후 연산들이 추적되는 것을 방지할 수 있습니다. (출처)Example (Source: here)
modelA = nn.Linear(10, 10) modelB = nn.Linear(10, 10) modelC = nn.Linear(10, 10) x = torch.randn(1, 10) a = modelA(x) b = modelB(a.detach()) b.mean().backward() print(modelA.weight.grad) print(modelB.weight.grad)
c = modelC(a) c.mean().backward() print(modelA.weight.grad) print(modelC.weight.grad)
'Data > Machine learning' 카테고리의 다른 글
Dilated Causal Convolution from WaveNet (0) 2021.03.01 [PyTorch] .detach() in Loss Function (0) 2021.01.28 The Boosting Algorithm (0) 2021.01.12