ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [PyTorch] .detach()
    Data/Machine learning 2021. 1. 28. 16:42

    Tensor가 기록을 추적하는 것을 중단하게 하려면, .detach()를 호출하여 연산 기록으로부터 분리(detach)하여 이후 연산들이 추적되는 것을 방지할 수 있습니다. (출처)

    Example (Source: here)

    modelA = nn.Linear(10, 10)
    modelB = nn.Linear(10, 10)
    modelC = nn.Linear(10, 10)
    
    x = torch.randn(1, 10)
    a = modelA(x)
    b = modelB(a.detach())
    b.mean().backward()
    print(modelA.weight.grad)
    print(modelB.weight.grad)

    Note that the derivatives w.r.t weights in modelA are not computed.

    c = modelC(a)
    c.mean().backward()
    print(modelA.weight.grad)
    print(modelC.weight.grad)

    Note that the derivatives w.r.t weights in modelA are computed.

     

     

     

     

    'Data > Machine learning' 카테고리의 다른 글

    Dilated Causal Convolution from WaveNet  (0) 2021.03.01
    [PyTorch] .detach() in Loss Function  (0) 2021.01.28
    The Boosting Algorithm  (0) 2021.01.12

    Comments