athena.utils.checkpoint

checkpoint manager

Module Contents

Classes

Checkpoint

A wrapper around TensorFlow's tf.train.Checkpoint

class athena.utils.checkpoint.Checkpoint(checkpoint_directory=None, use_dev_loss=True, model=None, **kwargs)

Bases: tensorflow.train.Checkpoint

A wrapper around TensorFlow's tf.train.Checkpoint

Parameters
  • checkpoint_directory – the directory for checkpoints

  • summary_directory – the directory for summaries used in TensorBoard

  • __init__ – provide the optimizer and the model as keyword arguments

  • __call__ – save the model

Example

>>> transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim)
>>> optimizer = tf.keras.optimizers.Adam()
>>> ckpt = Checkpoint(checkpoint_directory='./train', summary_directory='./event',
...                   transformer=transformer, optimizer=optimizer)
>>> solver = BaseSolver(transformer)
>>> for epoch in dataset:
...     ckpt()

_file_compatible(use_dev_loss)

Convert the n-best record file to a CSV file.

Adds “index” and “Accuracy” columns when the existing n-best file is not in CSV format.
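The exact layout of the n-best record file is internal to Athena; the sketch below only illustrates the idea of such a conversion, assuming a hypothetical plain-text file with one “checkpoint_name loss” pair per line (nbest_to_csv and both file paths are made-up names, not part of this module).

import csv

def nbest_to_csv(n_best_path, csv_path):
    # Hypothetical input layout: "checkpoint_name loss" per line, no header.
    rows = []
    with open(n_best_path) as f:
        for i, line in enumerate(f):
            name, loss = line.split()
            # "Accuracy" is left empty because the old file carries no metric column.
            rows.append({"index": i, "checkpoint": name, "loss": float(loss), "Accuracy": ""})
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["index", "checkpoint", "loss", "Accuracy"])
        writer.writeheader()
        writer.writerows(rows)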

_compare_and_save_best(loss, metrics, save_path, training=False)

Compare the current results with the recorded best loss and the N best metrics, and save the model if it ranks among the best.
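A minimal sketch of that compare-and-save idea, not Athena's actual implementation: the NBestKeeper class, its fields, and the save_fn callback below are illustrative assumptions.

import heapq

class NBestKeeper:
    """Keep the best-loss checkpoint plus the N checkpoints with the highest metric."""

    def __init__(self, n_best=5):
        self.best_loss = float("inf")
        self.n_best = n_best
        self._heap = []  # min-heap of (metric, save_path); worst of the N best sits on top

    def compare_and_save(self, loss, metric, save_path, save_fn):
        saved = False
        if loss < self.best_loss:            # new best dev loss: always save
            self.best_loss = loss
            save_fn(save_path)
            saved = True
        if len(self._heap) < self.n_best:    # still room in the n-best list
            heapq.heappush(self._heap, (metric, save_path))
            if not saved:
                save_fn(save_path)
        elif metric > self._heap[0][0]:      # better than the worst of the current N best
            heapq.heapreplace(self._heap, (metric, save_path))
            if not saved:
                save_fn(save_path)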

compute_nbest_avg(model_avg_num, sort_by=None, sort_by_time=False, reverse=True)

Restore the model from the average of the n-best checkpoints.

If ‘sort_by_time’ is False, the n-best order is sorted by ‘sort_by’; if ‘sort_by_time’ is True, the newest models are selected; if ‘reverse’ is True, the largest values in the sorted order are selected.
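The sketch below shows one way such n-best averaging can work, assuming a list of checkpoint paths and a parallel list of per-checkpoint records holding the metric values and file times; the function name and the record layout are assumptions, not Athena's code.

import tensorflow as tf

def average_checkpoints(model, ckpt_paths, records, model_avg_num,
                        sort_by="Accuracy", sort_by_time=False, reverse=True):
    # records[i] is assumed to hold {"Accuracy": ..., "mtime": ...} for ckpt_paths[i]
    key = (lambda r: r["mtime"]) if sort_by_time else (lambda r: r[sort_by])
    order = sorted(range(len(records)), key=lambda i: key(records[i]), reverse=reverse)
    selected = [ckpt_paths[i] for i in order[:model_avg_num]]

    sums = None
    ckpt = tf.train.Checkpoint(model=model)
    for path in selected:
        ckpt.restore(path).expect_partial()   # load one of the selected checkpoints
        weights = [v.numpy() for v in model.trainable_variables]
        sums = weights if sums is None else [s + w for s, w in zip(sums, weights)]
    # write the element-wise average back into the live model
    for var, total in zip(model.trainable_variables, sums):
        var.assign(total / len(selected))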

__call__(loss=None, metrics=None, training=False)

Save the model.

restore_from_best()

Restore the model from the best checkpoint.
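Putting the two together, a typical loop might look like the following; ‘solver.train’, ‘solver.evaluate’, ‘trainset’, ‘devset’ and ‘num_epochs’ are placeholder names, not part of this module.

>>> for epoch in range(num_epochs):
...     solver.train(trainset)
...     loss, metrics = solver.evaluate(devset)
...     ckpt(loss=loss, metrics=metrics)
>>> ckpt.restore_from_best()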