
What does .eval() do in PyTorch?

DS-Lee 2021. 4. 14. 18:21

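.eval() is typically called before running inference with a model that contains BN (batch normalization) and/or Dropout layers. It switches the module and all of its submodules to evaluation mode. In that mode, BN layers normalize with running_mean and running_var instead of the mean and variance computed from each mini-batch, and they stop updating those running statistics. Dropout layers are disabled and pass their inputs through unchanged. Note that .eval() does not turn off gradient computation (that is what torch.no_grad() is for), so the parameters of a module in evaluation mode can still be trained.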
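Below is a minimal sketch of what .eval() changes. It uses a standalone nn.BatchNorm1d and nn.Dropout with random inputs, chosen only for illustration. It also checks that, in training mode, PyTorch updates the running statistics as an exponential moving average with momentum (0.1 by default). This is how running_mean and running_var come to summarize the whole training set.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(3)      # running_mean starts at 0, running_var at 1; momentum=0.1
drop = nn.Dropout(p=0.5)
x = torch.randn(8, 3) * 2.0 + 5.0  # illustrative mini-batch whose mean is far from 0

# Training mode (the default): BN normalizes with the mini-batch statistics
# and updates its running estimates as an exponential moving average.
bn.train()
out_train = bn(x)
m = bn.momentum
expected_mean = (1 - m) * torch.zeros(3) + m * x.mean(dim=0)
expected_var = (1 - m) * torch.ones(3) + m * x.var(dim=0)  # x.var defaults to the unbiased estimate
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-4))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-4))    # True

# Evaluation mode: BN normalizes with running_mean / running_var and
# leaves them untouched; Dropout becomes the identity.
mean_before = bn.running_mean.clone()
bn.eval()
drop.eval()
out_eval = bn(x)
print(torch.equal(bn.running_mean, mean_before))  # True: no update in eval mode
print(torch.equal(drop(x), x))                    # True: Dropout passes x through
```

In training mode, each feature of out_train has mean close to 0 because BN uses the batch's own mean. out_eval is normalized with the running statistics, which after one step are still far from this batch's statistics, so its values stay well above 0.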

Fine-tuning

When fine-tuning, it is important to use the running_mean and running_var of the trained model and to keep them fixed during fine-tuning. The model is usually trained on a large dataset, so its running_mean and running_var have been accumulated over that whole dataset. They therefore carry more "global" information than the statistics of the (typically smaller) fine-tuning data. To summarize, the ideal approach for fine-tuning is:

  • The running_mean and running_var of the trained model are used, and they are not updated during fine-tuning.
  • However, the weights and biases of the model should still be updated, including BN's weights and biases (its affine parameters).
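The two points above are compatible because .eval() only changes how BN obtains its statistics. It does not stop gradients. The small sketch below uses a standalone BN layer and an arbitrary squared-output loss, both illustrative assumptions. After one SGD step in evaluation mode, running_mean and running_var are unchanged while BN's weight has moved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(3)
bn.eval()  # normalize with running_mean / running_var and do not update them

mean_before = bn.running_mean.clone()
var_before = bn.running_var.clone()
weight_before = bn.weight.detach().clone()

# Arbitrary illustrative loss; any differentiable loss would do.
optimizer = torch.optim.SGD(bn.parameters(), lr=0.1)
x = torch.randn(8, 3)
loss = bn(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()   # gradients still flow through a module in eval mode
optimizer.step()

print(torch.equal(bn.running_mean, mean_before))       # True: statistics stay fixed
print(torch.equal(bn.running_var, var_before))         # True
print(torch.equal(bn.weight.detach(), weight_before))  # False: BN's weight was updated
```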

This can be implemented as follows (in the case of linear evaluation):

  1. Load the trained model (e.g., encoder) and add a linear layer on top.
  2. Call encoder.eval() to put the encoder in evaluation mode, and linear_layer.train() to keep the linear layer in training mode.
  3. Fine-tune.
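The steps above might look roughly like this in PyTorch. The encoder below is a hypothetical stand-in (a small nn.Sequential containing a BN layer). In practice you would load your trained model instead, e.g. with load_state_dict. The helper set_modes is also hypothetical and simply applies step 2.

```python
import torch
import torch.nn as nn

# Step 1 (hypothetical stand-in): a small "pretrained" encoder with a BN layer.
encoder = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
)
linear_layer = nn.Linear(32, 2)  # new linear head on top of the encoder

def set_modes():
    # Step 2: BN in the encoder uses (and does not update) its stored
    # running statistics; the head stays in training mode.
    encoder.eval()
    linear_layer.train()

set_modes()
print(encoder.training, linear_layer.training)  # False True

# Step 3 starts from forward passes like this one.
x = torch.randn(4, 10)
logits = linear_layer(encoder(x))
print(logits.shape)  # torch.Size([4, 2])
```

.eval() and .train() act recursively on submodules, so encoder.eval() also disables any Dropout inside the encoder. If encoder and linear_layer are wrapped in a parent module and .train() is called on that parent (for example, at the start of each epoch), the encoder switches back to training mode. Re-applying set_modes() at the start of every epoch avoids that.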

When setting up the optimizer, pass it the parameters of both encoder and linear_layer. You can give them different learning rates, e.g., 1e-4 for encoder's parameters and 1e-3 for linear_layer's parameters (reference).
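Here is a sketch of one fine-tuning step with per-module learning rates, using optimizer parameter groups. The choice of torch.optim.Adam, the hypothetical encoder, and the dummy inputs and labels are assumptions for illustration. The source does not specify which optimizer to use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical encoder and head, for illustration only.
encoder = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU())
linear_layer = nn.Linear(32, 2)

# One parameter group per module, each with its own learning rate.
optimizer = torch.optim.Adam([
    {"params": encoder.parameters(), "lr": 1e-4},
    {"params": linear_layer.parameters(), "lr": 1e-3},
])

encoder.eval()        # fixed BN statistics
linear_layer.train()

x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))  # dummy labels

bn = encoder[1]
stats_before = bn.running_mean.clone()
weights_before = encoder[0].weight.detach().clone()

loss = F.cross_entropy(linear_layer(encoder(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print([g["lr"] for g in optimizer.param_groups])                # [0.0001, 0.001]
print(torch.equal(bn.running_mean, stats_before))                # True: BN statistics fixed
print(torch.equal(encoder[0].weight.detach(), weights_before))   # False: encoder still trained
```

Options not set in a group (here, Adam's betas and eps) fall back to the optimizer's defaults.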