Access 'Decayed learning rate' in TF
DS-Lee
2020. 10. 9. 09:50
Suppose we use tf.keras.optimizers.schedules.ExponentialDecay as the learning-rate schedule and would like to print the current decayed learning rate from a Callback.
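To make "decayed learning rate" concrete: according to the TensorFlow documentation, ExponentialDecay computes initial_learning_rate * decay_rate ^ (step / decay_steps), where step / decay_steps is floored when staircase=True. The sketch below restates that formula in plain Python so that the values a callback prints can be checked by hand. The function name exponential_decay is hypothetical and is not a TF API.

```python
import math


def exponential_decay(step, initial_learning_rate, decay_steps, decay_rate, staircase=False):
    """Plain-Python restatement of the formula documented for
    tf.keras.optimizers.schedules.ExponentialDecay (illustration only)."""
    p = step / decay_steps
    if staircase:
        # Staircase mode decays in discrete intervals of decay_steps
        p = math.floor(p)
    return initial_learning_rate * decay_rate ** p
```

For example, with initial_learning_rate=0.1, decay_steps=1000 and decay_rate=0.96, the rate is 0.1 at step 0 and 0.096 at step 1000. With staircase=True it stays at 0.096 until step 2000.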
Normally, we can print the learning rate with the following Callback class:
import tensorflow as tf

class LearningRateTracker(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Prints whatever the optimizer stores as its learning rate
        print("current lr: ", self.model.optimizer.learning_rate)
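However, this does not give the current decayed learning rate. When the optimizer is built with a schedule such as ExponentialDecay, optimizer.learning_rate (in the TF 2.x optimizers this post targets) holds the schedule object itself, not the value it produces at the current step.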
To print the current decayed learning rate, use the following Callback class instead:
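import tensorflow as tf

class LearningRateTracker(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # _decayed_lr() is a private method of the TF 2.x (OptimizerV2) optimizers;
        # it evaluates the learning-rate schedule at the optimizer's current step.
        current_decayed_lr = self.model.optimizer._decayed_lr(tf.float32).numpy()
        print("current decayed lr: {:0.7f}".format(current_decayed_lr))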
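Because _decayed_lr is private, it is not part of the public API and may be missing or behave differently in other TF versions. A possible alternative, sketched under the assumption that you keep a reference to the schedule object you passed to the optimizer, is to evaluate the schedule yourself at optimizer.iterations (the number of batches processed so far). The class name ScheduleLRTracker and its history list are hypothetical additions for illustration.

```python
import numpy as np
import tensorflow as tf


class ScheduleLRTracker(tf.keras.callbacks.Callback):
    """Hypothetical callback: prints the rate the schedule yields at the
    optimizer's current step, without calling private optimizer methods.

    Assumption: `schedule` is the same object that was given to the optimizer.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule
        self.history = []  # decayed lr recorded at the end of each epoch

    def on_epoch_end(self, epoch, logs=None):
        # Number of batches (optimizer steps) processed so far
        step = self.model.optimizer.iterations
        # Evaluate the schedule at that step and convert to a Python float
        lr = float(np.asarray(self.schedule(step)))
        self.history.append(lr)
        print("epoch {}: current decayed lr: {:0.7f}".format(epoch, lr))
```

Usage sketch: build schedule = tf.keras.optimizers.schedules.ExponentialDecay(...), compile the model with tf.keras.optimizers.SGD(learning_rate=schedule), and pass callbacks=[ScheduleLRTracker(schedule)] to model.fit. The trade-off is that the callback must be handed the schedule explicitly, in exchange for relying only on public calls.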