What does .eval() do in PyTorch?
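`.eval()` is typically called before inference on models that contain BatchNorm (BN) and/or Dropout layers. It switches the module and all of its submodules to evaluation mode (it is equivalent to `.train(False)`). In this mode, BN layers normalize with their accumulated `running_mean` and `running_var` instead of the `mean` and `var` computed from each mini-batch, and Dropout layers stop dropping activations and pass their input through unchanged. `.eval()` does not disable gradient computation, so parameters can still be trained while a module is in eval mode.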
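A small check of the behavior described above, assuming a recent PyTorch version. The layer sizes and the input batch are arbitrary illustrative values. It compares a `BatchNorm1d` and a `Dropout` layer in training mode and in evaluation mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(3)    # running_mean starts at 0, running_var at 1
drop = nn.Dropout(p=0.5)
x = torch.randn(8, 3) * 2.0 + 5.0  # batch statistics far from the initial running stats

# Training mode (the default): BN normalizes with the mini-batch mean/var
# and, as a side effect, updates running_mean/running_var (momentum=0.1 by default).
bn.train()
y_train = bn(x)
print(y_train.mean(dim=0))  # approximately zero for every feature

# Evaluation mode: BN normalizes with running_mean/running_var and does not update them.
bn.eval()
stats_before = bn.running_mean.clone()
y_eval = bn(x)
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(y_eval, expected))            # True
print(torch.equal(bn.running_mean, stats_before))  # True

# Dropout: zeroes random elements (and rescales the rest by 1/(1-p)) in train mode,
# and acts as the identity in eval mode.
drop.train()
print((drop(x) == 0).any())     # with p=0.5, some entries are zeroed
drop.eval()
print(torch.equal(drop(x), x))  # True
```

In eval mode, the BN output matches the normalization formula computed from `running_mean`, `running_var`, `weight`, `bias`, and `eps`, and the running statistics stay where they were.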
Fine-tuning
When fine-tuning, it is important to use the `running_mean` and `running_var` of the pretrained model and to keep them fixed during fine-tuning. The model is usually pretrained on a large dataset, so its `running_mean` and `running_var` have been accumulated over that whole dataset. Therefore, they carry more "global" information than statistics estimated from the fine-tuning mini-batches. To summarize, the ideal approach for fine-tuning is:
- The `running_mean` and `running_var` of the pretrained model are used, and they are not updated during fine-tuning.
- Yet, the weights and biases of the model, including BN's weights and biases, should still be updated.
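Both points can be checked with a toy model. This is a sketch, not the author's code: the two-layer `encoder` is a hypothetical stand-in for a pretrained model, and the squared-output loss and SGD optimizer are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained encoder: Linear -> BatchNorm.
encoder = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
bn = encoder[1]

# Eval mode: BN uses its stored running statistics and does not update them.
encoder.eval()
mean_before = bn.running_mean.clone()
var_before = bn.running_var.clone()
weight_before = bn.weight.detach().clone()

# One training step on a toy loss; gradients still flow in eval mode.
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
x = torch.randn(16, 4)
loss = encoder(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(bn.weight.grad is not None)                      # True: BN's weight received a gradient
print(torch.equal(bn.running_mean, mean_before))       # True: running_mean unchanged
print(torch.equal(bn.running_var, var_before))         # True: running_var unchanged
print(torch.equal(bn.weight.detach(), weight_before))  # False: BN's weight was updated
```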
This can be implemented as follows (in the case of linear evaluation):
- Load the trained model (e.g., an encoder) and add a linear layer on top.
- Set `encoder.eval()` and `linear_layer.train()`.
- Fine-tune.
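A minimal sketch of these steps, assuming PyTorch. The encoder architecture, the random tensors standing in for a real dataset, and the single learning rate are illustrative assumptions; loading real pretrained weights is only indicated in a comment, with a hypothetical file name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical pretrained encoder. In practice you would build your architecture and
# load its weights, e.g. encoder.load_state_dict(torch.load("encoder.pt"))  # hypothetical path
encoder = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())
linear_layer = nn.Linear(32, 10)  # new linear head on top of the encoder

# Encoder in eval mode (BN uses and keeps its running statistics),
# linear layer in train mode.
encoder.eval()
linear_layer.train()

# Both modules' parameters go to the optimizer, so the encoder is updated too.
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(linear_layer.parameters()), lr=1e-3
)
criterion = nn.CrossEntropyLoss()

# Fine-tune (random tensors stand in for a real DataLoader).
for step in range(5):
    x = torch.randn(8, 16)
    y = torch.randint(0, 10, (8,))
    logits = linear_layer(encoder(x))
    loss = criterion(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(encoder.training, linear_layer.training)  # False True
```

One pitfall: if `encoder` and `linear_layer` are wrapped in a single parent module, calling `.train()` on the parent switches the encoder's BN layers back to training mode, because `.train()` and `.eval()` are applied recursively to all submodules. In that case, call `encoder.eval()` again after every `.train()` call on the parent (for example, at the start of each epoch).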
When setting up the optimizer, pass in the parameters of both `encoder` and `linear_layer`. You can give the two sets of parameters different learning rates, such as 1e-4 for `encoder` and 1e-3 for `linear_layer` (reference).
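Per-module learning rates can be expressed with optimizer parameter groups. This is a sketch that assumes Adam as the optimizer (the text above does not name one) and uses hypothetical `encoder`/`linear_layer` shapes:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the pretrained encoder and the new linear head.
encoder = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())
linear_layer = nn.Linear(32, 10)

# One parameter group per module, each with its own learning rate.
optimizer = torch.optim.Adam([
    {"params": encoder.parameters(), "lr": 1e-4},
    {"params": linear_layer.parameters(), "lr": 1e-3},
])

for group in optimizer.param_groups:
    # learning rate and number of scalar parameters in the group
    print(group["lr"], sum(p.numel() for p in group["params"]))
# 0.0001 608
# 0.001 330
```

Any option not set inside a group (e.g., `weight_decay`) falls back to the default passed to the optimizer constructor.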