athena.utils.checkpoint

Checkpoint manager

Module Contents

Classes

Checkpoint – A wrapper for TensorFlow checkpoint
class athena.utils.checkpoint.Checkpoint(checkpoint_directory=None, model=None, **kwargs)

Bases: tensorflow.train.Checkpoint

A wrapper around TensorFlow's tf.train.Checkpoint.

Parameters:
  • checkpoint_directory – the directory in which checkpoints are stored
  • summary_directory – the directory for summaries used by TensorBoard

The optimizer and model are provided to __init__; calling the instance (__call__) saves the model.

Example

    >>> # dataset_builder and dataset are assumed to come from the training setup
    >>> transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim)
    >>> optimizer = tf.keras.optimizers.Adam()
    >>> ckpt = Checkpoint(checkpoint_directory='./train', summary_directory='./event',
    ...                   model=transformer, optimizer=optimizer)
    >>> solver = BaseSolver(transformer)
    >>> for epoch in dataset:
    ...     ckpt()  # save a checkpoint after each epoch

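The calling pattern in the example (construct the wrapper once, then call the instance to save) can be illustrated without TensorFlow. The sketch below uses a hypothetical stand-in, ToyCheckpoint, that writes its tracked state to numbered JSON files; the class name, the "ckpt" prefix and the JSON format are assumptions for illustration and are not part of Athena. tf.train.Checkpoint.save similarly derives file names from a prefix plus its save counter (for example ckpt-1).

```python
import json
import os
import tempfile


class ToyCheckpoint:
    """Hypothetical stand-in for a callable checkpoint wrapper (not TensorFlow).

    Tracks named objects and writes them to numbered JSON files, loosely
    mimicking the "<prefix>-<n>" naming used by tf.train.Checkpoint.save.
    """

    def __init__(self, checkpoint_directory, **tracked):
        os.makedirs(checkpoint_directory, exist_ok=True)
        # Assumed prefix name; every save goes under this directory.
        self.prefix = os.path.join(checkpoint_directory, "ckpt")
        self.tracked = tracked
        self.save_counter = 0

    def __call__(self):
        # Each call writes a new numbered checkpoint and returns its path.
        self.save_counter += 1
        path = f"{self.prefix}-{self.save_counter}.json"
        with open(path, "w") as f:
            json.dump(self.tracked, f)
        return path


ckpt_dir = tempfile.mkdtemp()
ckpt = ToyCheckpoint(ckpt_dir, model={"w": [0.1, 0.2]}, step=0)
paths = [ckpt() for _ in range(3)]  # e.g. one save per epoch
```

Each call returns the path it wrote, so the third call produces a file ending in ckpt-3.json.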
_compare_and_save_best(self, loss, metrics, save_path)

Compare loss and metrics against best_loss and the N best metrics, and save the best model.
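The bookkeeping behind a best loss plus an N-best metric record can be sketched in plain Python. compare_and_record and the state layout below are hypothetical; the sketch assumes that a lower loss and a higher metric are better, and it only updates the record, leaving the actual model saving to the caller. Athena's method may differ in detail.

```python
import math


def compare_and_record(state, loss, metric, path, n):
    """Hypothetical sketch: update the best loss and an N-best metric record.

    state: {"best_loss": float, "n_best": {checkpoint_path: metric}}
    Assumes lower loss and higher metric are better.
    Returns (loss_improved, evicted_paths).
    """
    loss_improved = loss < state["best_loss"]
    if loss_improved:
        # The caller would save the "best loss" model at this point.
        state["best_loss"] = loss
    n_best = state["n_best"]
    n_best[path] = metric
    # Rank checkpoints from best to worst metric and drop all beyond the top n.
    ranked = sorted(n_best, key=n_best.get, reverse=True)
    evicted = ranked[n:]
    for p in evicted:
        del n_best[p]
    return loss_improved, evicted


state = {"best_loss": math.inf, "n_best": {}}
history = [
    (2.0, 0.60, "ckpt-1"),
    (1.5, 0.70, "ckpt-2"),
    (1.7, 0.65, "ckpt-3"),
    (1.2, 0.80, "ckpt-4"),
]
for loss, acc, path in history:
    compare_and_record(state, loss, acc, path, n=2)
```

After the four steps, best_loss is 1.2 and the record keeps only ckpt-2 and ckpt-4, the two highest-scoring checkpoints.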

compute_nbest_avg(self, model_avg_num)

Restore the checkpoint obtained by averaging the n best checkpoints.
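Checkpoint averaging takes an element-wise mean of each variable over several checkpoints. Assuming compute_nbest_avg averages the variables of the model_avg_num best checkpoints, as its name suggests, the arithmetic looks like the following plain-Python sketch; average_checkpoints and the dict-of-lists representation of variables are hypothetical.

```python
def average_checkpoints(checkpoints):
    """Hypothetical sketch: element-wise mean of variables across checkpoints.

    checkpoints: list of dicts mapping variable name -> list of floats,
    all sharing the same variable names and lengths.
    """
    if not checkpoints:
        raise ValueError("need at least one checkpoint to average")
    count = len(checkpoints)
    averaged = {}
    for name in checkpoints[0]:
        # zip(*...) groups the i-th element of this variable across checkpoints.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(col) / count for col in columns]
    return averaged


n_best = [
    {"dense/kernel": [1.0, 2.0], "dense/bias": [0.0]},
    {"dense/kernel": [3.0, 4.0], "dense/bias": [1.0]},
]
avg = average_checkpoints(n_best)
# avg == {"dense/kernel": [2.0, 3.0], "dense/bias": [0.5]}
```

A TensorFlow version would apply the same mean to each tracked variable's tensor values before assigning them back to the model.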

__call__(self, loss=None, metrics=None)

Save the model.

restore_from_best(self)

Restore the model from the best checkpoint.
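Restoring from the best model first requires picking the best checkpoint. The sketch below, best_from_record, reads a hypothetical tab-separated record of checkpoint path and metric value and returns the path with the highest metric; the record format and the higher-is-better convention are assumptions. A TensorFlow implementation would then pass that path to tf.train.Checkpoint.restore.

```python
def best_from_record(lines):
    """Hypothetical sketch: pick the best checkpoint from 'path<TAB>metric' lines.

    The record format and "higher metric is better" are assumptions.
    Returns None when the record holds no entries.
    """
    scores = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        path, value = line.split("\t")
        scores[path] = float(value)
    if not scores:
        return None
    return max(scores, key=scores.get)


record = ["ckpt-2\t0.70\n", "ckpt-4\t0.80\n", "\n"]
best = best_from_record(record)
# best == "ckpt-4"
```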