athena.utils.checkpoint
Checkpoint manager.
Module Contents
Classes
Checkpoint | A wrapper for TensorFlow checkpoint
class athena.utils.checkpoint.Checkpoint(checkpoint_directory=None, model=None, **kwargs)
Bases: tensorflow.train.Checkpoint

A wrapper for TensorFlow checkpoint that saves the model and keeps track of the best checkpoints.
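Parameters:
- checkpoint_directory – the directory in which checkpoints are saved
- summary_directory – the directory for summaries used by TensorBoard
- model – the model to checkpoint
- **kwargs – other objects to checkpoint together with the model, such as the optimizer

__init__ provides the optimizer and model; __call__ saves the model.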
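Example

    import tensorflow as tf
    from athena.utils.checkpoint import Checkpoint

    # dataset_builder, SpeechTransformer, BaseSolver and dataset come from the surrounding training script
    transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim)
    optimizer = tf.keras.optimizers.Adam()
    ckpt = Checkpoint(checkpoint_directory='./train', summary_directory='./event',
                      model=transformer, optimizer=optimizer)
    solver = BaseSolver(transformer)
    for epoch in dataset:
        # train one epoch with solver, then save the model
        ckpt()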
_compare_and_save_best(self, loss, metrics, save_path)
Compare the current loss and metrics with the best loss and the N best metric values recorded so far, and save the model if it ranks among the best.
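The comparison in _compare_and_save_best can be pictured with a small plain-Python sketch. It is not Athena's code and does not touch TensorFlow: the class NBestTracker, its n parameter, the single scalar metric, and the assumption that a higher metric is better (as with accuracy) are all illustrative. A min-heap holds the N best entries so that the worst of them sits at the root and can be evicted cheaply.

```python
import heapq
import math


class NBestTracker:
    """Hypothetical helper: tracks the lowest loss and the N highest metric values."""

    def __init__(self, n):
        if n < 1:
            raise ValueError("n must be at least 1")
        self.n = n
        self.best_loss = math.inf
        # Min-heap of (metric, save_path); the root is the worst of the N best.
        self._heap = []

    def compare(self, loss, metric, save_path):
        """Return (new_best_loss, in_n_best) for the checkpoint at save_path."""
        new_best_loss = loss < self.best_loss
        if new_best_loss:
            self.best_loss = loss
        if len(self._heap) < self.n:
            heapq.heappush(self._heap, (metric, save_path))
            in_n_best = True
        elif metric > self._heap[0][0]:
            # Evict the current worst of the N best and insert the new entry.
            heapq.heapreplace(self._heap, (metric, save_path))
            in_n_best = True
        else:
            in_n_best = False
        return new_best_loss, in_n_best

    def n_best(self):
        """Save paths of the N best checkpoints, highest metric first."""
        return [path for _, path in sorted(self._heap, reverse=True)]
```

With n=2, feeding metrics 0.70, 0.75, 0.72 and 0.60 keeps the checkpoints scored 0.75 and 0.72; a caller would save whenever either returned flag is true.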
compute_nbest_avg(self, model_avg_num)
Restore a checkpoint averaged over the n-best checkpoints (model_avg_num presumably sets how many are averaged).
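compute_nbest_avg is documented only as restoring an n-best average checkpoint. Checkpoint averaging usually means taking the element-wise mean of each variable over the selected checkpoints. The sketch below shows only that arithmetic and is not Athena's implementation: variables are represented as hypothetical dicts mapping a name to a flat list of floats, standing in for tensors that would actually be read from disk (for example with tf.train.load_checkpoint).

```python
def average_checkpoints(checkpoints):
    """Element-wise mean of every variable across several checkpoints.

    checkpoints: list of dicts mapping a variable name to a flat list of floats,
    a hypothetical stand-in for the variables stored in N saved checkpoints.
    """
    if not checkpoints:
        raise ValueError("need at least one checkpoint to average")
    names = checkpoints[0].keys()
    if any(ckpt.keys() != names for ckpt in checkpoints[1:]):
        raise ValueError("checkpoints hold different variables")
    count = len(checkpoints)
    averaged = {}
    for name in names:
        tensors = [ckpt[name] for ckpt in checkpoints]
        if len({len(t) for t in tensors}) != 1:
            raise ValueError(f"variable {name!r} has different sizes across checkpoints")
        # zip(*tensors) groups the same element position from every checkpoint.
        averaged[name] = [sum(values) / count for values in zip(*tensors)]
    return averaged
```

Averaging two checkpoints whose w values are [1.0, 2.0] and [3.0, 4.0] gives [2.0, 3.0].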
__call__(self, loss=None, metrics=None)
Save the model.
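The documented behaviour of __call__ (save the model, optionally given a loss and metrics) can be sketched without TensorFlow by injecting the save step. In a real subclass of tf.train.Checkpoint the save would presumably be self.save(file_prefix), which returns the path of the checkpoint it wrote. The class CallableCheckpoint, the save_fn argument and the "ckpt" file prefix are hypothetical, and remembering only the lowest-loss save is a simplification.

```python
import math
import os


class CallableCheckpoint:
    """Hypothetical stand-in: saves on every call and remembers the lowest-loss save."""

    def __init__(self, checkpoint_directory, save_fn):
        # "ckpt" is an illustrative file prefix, not necessarily the one Athena uses.
        self.checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        # save_fn(prefix) -> path plays the role of tf.train.Checkpoint.save(file_prefix).
        self.save_fn = save_fn
        self.best_loss = math.inf
        self.best_path = None

    def __call__(self, loss=None, metrics=None):
        # metrics is accepted to mirror the documented signature but ignored here.
        save_path = self.save_fn(self.checkpoint_prefix)
        if loss is not None and loss < self.best_loss:
            self.best_loss = loss
            self.best_path = save_path
        return save_path
```

Injecting the save step keeps the control flow testable while leaving the actual serialization to tf.train.Checkpoint.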
restore_from_best(self)
Restore from the best model.
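restore_from_best first has to locate the best checkpoint. One way to do that is sketched below under clearly stated assumptions: the scores are kept in a hypothetical record with one tab-separated checkpoint path and metric per line, and a higher metric is better by default. The helpers load_record and pick_best are illustrative and not part of Athena.

```python
def load_record(lines):
    """Parse 'checkpoint_path<TAB>metric' lines (hypothetical format) into a dict."""
    scores = {}
    for line in lines:
        line = line.strip()
        if line:
            # rsplit keeps any tabs inside the path itself intact.
            path, metric = line.rsplit("\t", 1)
            scores[path] = float(metric)
    return scores


def pick_best(scores, higher_is_better=True):
    """Return the checkpoint path with the best metric, or None if scores is empty."""
    if not scores:
        return None
    choose = max if higher_is_better else min
    return choose(scores, key=scores.get)
```

An open file object can be passed straight to load_record, and the chosen path could then be handed to tf.train.Checkpoint.restore(save_path).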