athena.utils.checkpoint
Checkpoint manager.
Module Contents
Classes
Checkpoint – A wrapper for TensorFlow checkpoints
- class athena.utils.checkpoint.Checkpoint(checkpoint_directory=None, use_dev_loss=True, model=None, **kwargs)
Bases: tensorflow.train.Checkpoint
A wrapper for TensorFlow checkpoints.
- Parameters
checkpoint_directory – the directory in which checkpoints are saved
summary_directory – the directory for summaries used by TensorBoard
__init__ – provides the optimizer and model
__call__ – saves the model
Example
>>> import tensorflow as tf
>>> from athena.utils.checkpoint import Checkpoint
>>> # SpeechTransformer, BaseSolver, dataset_builder and dataset come from the training script
>>> transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim)
>>> optimizer = tf.keras.optimizers.Adam()
>>> ckpt = Checkpoint(checkpoint_directory='./train',
...                   model=transformer, optimizer=optimizer)
>>> solver = BaseSolver(transformer)
>>> for epoch in dataset:
...     ckpt()
- _file_compatible(use_dev_loss)
Convert the n_best file to a CSV file.
Adds “index” and “Accuracy” columns when the n_best file is not already in CSV format.
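The conversion described for _file_compatible can be pictured with the minimal sketch below. It assumes, hypothetically, that a legacy (non-CSV) n_best file holds one whitespace-separated "<index> <accuracy>" pair per line with no header; the real on-disk format is not documented on this page. The helper name to_csv_nbest is also hypothetical. Only the two column names, “index” and “Accuracy”, come from the docstring.

```python
import csv
import io


def to_csv_nbest(legacy_text: str) -> str:
    """Rewrite a header-less n_best listing as CSV text (hypothetical helper).

    Assumed legacy format: one '<index> <accuracy>' pair per non-empty line,
    separated by whitespace. The output starts with an 'index,Accuracy'
    header row, the two column names the docstring mentions.
    """
    out = io.StringIO()
    writer = csv.writer(out, lineterminator="\n")
    writer.writerow(["index", "Accuracy"])  # header added for non-CSV files
    for line in legacy_text.splitlines():
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        index, accuracy = fields
        writer.writerow([index, accuracy])
    return out.getvalue()
```

For example, the input "3 0.91\n7 0.93\n" becomes the three CSV lines index,Accuracy / 3,0.91 / 7,0.93.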
- _compare_and_save_best(loss, metrics, save_path, training=False)
Compare loss and metrics against best_loss and the N best metrics recorded so far, and save the best models.
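The bookkeeping behind _compare_and_save_best might look roughly like the pure-Python sketch below. The class NBestTracker, its update method, and the rule "keep a checkpoint if it sets a new best loss or enters the N best metrics" are illustrative assumptions, not the library's actual code; no TensorFlow calls are made, and save_path is only used as a dictionary key.

```python
import math
from typing import Dict


class NBestTracker:
    """Hypothetical helper: track the best loss and the N best metric values."""

    def __init__(self, n: int, higher_is_better: bool = True):
        self.n = n
        self.higher_is_better = higher_is_better
        self.best_loss = math.inf
        self.n_best: Dict[str, float] = {}  # save_path -> metric value

    def update(self, loss: float, metric: float, save_path: str) -> bool:
        """Record one evaluation; return True if the checkpoint should be kept."""
        keep = False
        if loss < self.best_loss:
            self.best_loss = loss  # new best loss
            keep = True
        self.n_best[save_path] = metric
        if len(self.n_best) > self.n:
            # Evict the worst entry so that only N checkpoints remain.
            pick_worst = min if self.higher_is_better else max
            worst = pick_worst(self.n_best, key=self.n_best.get)
            del self.n_best[worst]
        if save_path in self.n_best:
            keep = True  # the new checkpoint made it into the N best
        return keep
```

Keeping a dictionary keyed by save path makes it cheap to evict the worst entry and, later, to list the paths of the N best checkpoints.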
- compute_nbest_avg(model_avg_num, sort_by=None, sort_by_time=False, reverse=True)
Restore a checkpoint averaged over the n-best models.
If ‘sort_by_time’ is False, the n-best models are ordered by ‘sort_by’; if ‘sort_by_time’ is True, the newest models are selected; if ‘reverse’ is True, the models with the largest values in the sorted order are selected.
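The selection rules of compute_nbest_avg can be illustrated with the sketch below, which works on plain Python lists instead of TensorFlow variables. It assumes a hypothetical CkptRecord holding a save time, a dictionary of metric values, and flattened weights. The functions select_nbest and average_weights are also hypothetical, and raising an error when sort_by is missing is a choice made only for this sketch.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class CkptRecord:
    """Hypothetical record: save time, metric values, and flattened weights."""
    time: float
    metrics: Dict[str, float]
    weights: List[float]


def select_nbest(records: List[CkptRecord], model_avg_num: int,
                 sort_by: Optional[str] = None, sort_by_time: bool = False,
                 reverse: bool = True) -> List[CkptRecord]:
    """Pick the model_avg_num checkpoints to average."""
    if sort_by_time:
        # Newest first: larger timestamps are more recent.
        ordered = sorted(records, key=lambda r: r.time, reverse=True)
    else:
        if sort_by is None:
            raise ValueError("sort_by is required when sort_by_time is False")
        # reverse=True puts the largest metric values first.
        ordered = sorted(records, key=lambda r: r.metrics[sort_by], reverse=reverse)
    return ordered[:model_avg_num]


def average_weights(records: List[CkptRecord]) -> List[float]:
    """Element-wise mean of the selected checkpoints' weights."""
    n = len(records)
    return [sum(values) / n for values in zip(*(r.weights for r in records))]
```

With three records whose accuracies are 0.90, 0.95 and 0.80, select_nbest(records, 2, sort_by="Accuracy") keeps the 0.95 and 0.90 models, and average_weights then returns their element-wise mean.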
- __call__(loss=None, metrics=None, training=False)
Save the model.
- restore_from_best()
Restore from the best model.