What does .eval() do in PyTorch?
Data/Machine learning 2021. 4. 14. 18:21
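`.eval()` switches a module (and, recursively, all of its submodules) to evaluation mode. This changes the behavior of layers such as BatchNorm (BN) and Dropout, which is why it is typically called before running inference with models that contain them. In evaluation mode, Dropout is disabled, and BN normalizes its inputs with the accumulated `running_mean` and `running_var` instead of the mean and variance computed from each mini-batch. (In training mode, BN uses the mini-batch statistics and also updates `running_mean` and `running_var`.)

Fine-tuning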
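The fine-tuning advice that follows depends on the train/eval switch described above, so here is a minimal sketch of that switch on a single `nn.BatchNorm1d` layer. The layer size and the toy input are arbitrary choices for illustration. In training mode the layer normalizes with the mini-batch statistics and moves its running estimates toward them (PyTorch's default `momentum=0.1`); in evaluation mode it normalizes with `running_mean` and `running_var` and leaves them untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=3)
# Toy mini-batch whose statistics differ clearly from the initial
# running statistics (running_mean = 0, running_var = 1).
x = torch.randn(8, 3) * 2.0 + 5.0

# Training mode: normalize with the mini-batch mean/var and update the running estimates.
bn.train()
out_train = bn(x)

# Snapshot of the running statistics after one training-mode forward pass.
running_mean_before = bn.running_mean.clone()
running_var_before = bn.running_var.clone()

# Evaluation mode: normalize with running_mean/running_var; they are not updated.
bn.eval()
out_eval = bn(x)

# The same eval-mode result computed by hand from the stored statistics
# (weight=1 and bias=0 at initialization, so the affine step is the identity here).
manual = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
```

In training mode each output feature has (numerically) zero mean over the batch, while in evaluation mode the output matches the hand-computed normalization with the stored statistics, and a forward pass does not change those statistics.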
When fine-tuning, it is important to use the `running_mean` and `running_var` of the trained model and to keep them fixed during fine-tuning. This is because the model is usually trained on a large dataset, and `running_mean` and `running_var` are accumulated over that whole dataset, so they carry more "global" information than the statistics of a single fine-tuning mini-batch. To summarize, the ideal approach for fine-tuning is:

- The `running_mean` and `running_var` of the trained model are used, and they are not updated during fine-tuning.
- Yet, the weights and biases in the model, including BN's weights and biases, should still be updated.
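A minimal sketch of this two-part behavior, assuming a small hypothetical model with one BN layer whose running statistics are set by hand to stand in for pre-trained ones: after `model.eval()`, a few optimizer steps change BN's `weight` and `bias` but leave `running_mean` and `running_var` exactly as they were.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical "pre-trained" model; in practice it would be loaded from a checkpoint.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3), nn.Linear(3, 1))
bn = model[1]

# Pretend these running statistics were accumulated during pre-training.
with torch.no_grad():
    bn.running_mean.fill_(0.5)
    bn.running_var.fill_(2.0)

mean_before = bn.running_mean.clone()
var_before = bn.running_var.clone()
weight_before = bn.weight.detach().clone()
bias_before = bn.bias.detach().clone()

model.eval()  # BN now uses (and no longer updates) running_mean / running_var
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholder fine-tuning data.
x = torch.randn(16, 4)
y = torch.randn(16, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients still flow through BN in eval mode
    optimizer.step()  # updates all parameters, including BN's weight and bias
```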
This can be implemented as follows (in the case of linear evaluation):
- Load the trained model (e.g., `encoder`) and add a linear layer on top.
- Set `encoder.eval()` while keeping `linear_layer.train()`.
- Fine-tune.
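The three steps could look roughly like the sketch below. The encoder architecture, the 10-class head, the SGD optimizer, and the random batch are all placeholders assumed for illustration; in practice the encoder's weights would be loaded from a checkpoint.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Step 1 (hypothetical): a "trained" encoder plus a new linear layer on top.
encoder = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.5),
)
linear_layer = nn.Linear(16, 10)  # 10 output classes, assumed

# Step 2: encoder in eval mode (BN uses its fixed running stats, Dropout is off),
# linear layer in train mode.
encoder.eval()
linear_layer.train()

# Parameters of both modules are optimized.
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(linear_layer.parameters()), lr=1e-3
)

bn = encoder[1]
running_mean_before = bn.running_mean.clone()

# Step 3: fine-tune on a placeholder batch.
x = torch.randn(32, 8)
y = torch.randint(0, 10, (32,))
for _ in range(3):
    optimizer.zero_grad()
    logits = linear_layer(encoder(x))
    loss = nn.functional.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()
```

One pitfall with this setup: if `encoder` and `linear_layer` are wrapped in a single parent module, calling `.train()` on the parent recursively puts the encoder back into training mode, so `encoder.eval()` has to be called again afterwards (e.g., at the start of every epoch).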
When setting up the optimizer, pass the parameters of both `encoder` and `linear_layer`. You can also give different learning rates to the `encoder`'s parameters and the `linear_layer`'s parameters, e.g., 1e-4 and 1e-3, respectively (reference).
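In PyTorch, per-module learning rates are set through parameter groups, i.e., a list of dicts passed to the optimizer. A minimal sketch with the learning rates mentioned above; the module shapes and the choice of Adam are assumptions (any `torch.optim` optimizer accepts parameter groups).

```python
import torch
import torch.nn as nn

# Hypothetical modules standing in for the loaded encoder and the new linear head.
encoder = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())
linear_layer = nn.Linear(16, 10)

# One optimizer, two parameter groups: a smaller learning rate for the pre-trained
# encoder and a larger one for the freshly initialized linear layer.
optimizer = torch.optim.Adam([
    {"params": encoder.parameters(), "lr": 1e-4},
    {"params": linear_layer.parameters(), "lr": 1e-3},
])
```

Each group's learning rate can later be read or changed through `optimizer.param_groups[i]["lr"]`, which is also what learning-rate schedulers modify.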