
Access 'Decayed learning rate' in TF

DS-Lee 2020. 10. 9. 09:50

Suppose we use tf.keras.optimizers.schedules.ExponentialDecay as the optimizer's learning-rate schedule, and we'd like to print the current decayed learning rate from a Callback.
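As a concrete setup, a minimal sketch of such a schedule attached to an optimizer might look like this (the hyperparameter values are made up for illustration):

```python
import tensorflow as tf

# Hypothetical hyperparameters, chosen only for illustration.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,
    decay_steps=1000,
    decay_rate=0.96,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
```

The schedule object is callable: `lr_schedule(step)` returns the decayed rate at a given step.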

Normally, with a fixed learning rate, we can print it using the following Callback class:

class LearningRateTracker(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print("current lr: ", self.model.optimizer.learning_rate)
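To see why this falls short, note that in TF 2.x, when a schedule is attached, the `learning_rate` attribute holds the schedule object itself rather than a decayed float. A small sketch (hyperparameter values are made up):

```python
import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01, decay_steps=1000, decay_rate=0.96)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Printing the attribute shows the attached schedule,
# not the current decayed learning-rate value.
print(optimizer.learning_rate)
```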

However, the current decayed learning rate cannot be accessed that way: the attribute returns the schedule object, not the value in effect at the current training step.
To print the current decayed learning rate, use the following Callback class instead:

class LearningRateTracker(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # _decayed_lr evaluates the schedule at the optimizer's current step.
        current_decayed_lr = self.model.optimizer._decayed_lr(tf.float32).numpy()
        print("current decayed lr: {:0.7f}".format(current_decayed_lr))
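
Note that `_decayed_lr` is a private helper of TF 2.x optimizers and may change between versions. If you'd rather stay on public API, schedule objects are callable, so an equivalent callback (the class name here is my own) can evaluate the schedule at the optimizer's current iteration:

```python
import tensorflow as tf

class PublicLRTracker(tf.keras.callbacks.Callback):
    """Prints the current learning rate using only public API."""

    def on_epoch_end(self, epoch, logs=None):
        lr = self.model.optimizer.learning_rate
        # If a LearningRateSchedule is attached, evaluate it at the
        # current iteration count to get the decayed value.
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(self.model.optimizer.iterations)
        print("current decayed lr: {:0.7f}".format(float(lr)))
```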
