athena.models.kws.base

Base model class for keyword spotting (KWS) models.

Module Contents

Classes

BaseModel

Base class for models.

class athena.models.kws.base.BaseModel(**kwargs)

Bases: tensorflow.keras.Model

Base class for models. Subclasses must implement call().

abstract call(samples, training=None)

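Run the model's forward pass on samples. This method is abstract and must be implemented by subclasses. training indicates whether the call is made in training or inference mode.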
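A minimal sketch of how a subclass might satisfy the abstract call contract. To keep it runnable without TensorFlow, StandInBaseModel is a hypothetical plain-Python stand-in for BaseModel (a real subclass would extend the tf.keras.Model-based class and operate on tensors). The samples key "input" and the toy scoring rule are illustrative assumptions, not part of Athena's API.

```python
from typing import Dict, List, Optional


class StandInBaseModel:
    """Hypothetical plain-Python stand-in for BaseModel (not the real class)."""

    def call(self, samples: Dict[str, List[float]], training: Optional[bool] = None):
        # Abstract: concrete models must override this method.
        raise NotImplementedError()


class ToyKwsModel(StandInBaseModel):
    """Hypothetical model that produces one score per input frame."""

    def __init__(self, weight: float = 0.5):
        self.weight = weight

    def call(self, samples, training=None):
        # "input" is an assumed key holding one scalar feature per frame.
        return [self.weight * x for x in samples["input"]]
```

With these definitions, ToyKwsModel().call({"input": [1.0, 2.0]}) returns [0.5, 1.0], while calling the stand-in base class raises NotImplementedError.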

get_loss(outputs, samples, training=None)

Compute the loss from the model outputs and the corresponding samples.
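A minimal sketch of what a get_loss override could compute, assuming outputs are per-frame keyword probabilities and samples holds 0/1 labels under a hypothetical "output" key. Binary cross-entropy is an illustrative choice, not necessarily the loss Athena's KWS models use. The sketch is written as a free function for brevity and returns only the scalar loss; the real method's return structure is not documented here.

```python
import math
from typing import Dict, List, Optional


def get_loss(outputs: List[float], samples: Dict[str, List[int]],
             training: Optional[bool] = None, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted probabilities and labels."""
    labels = samples["output"]  # hypothetical label key
    if not outputs or len(labels) != len(outputs):
        raise ValueError("outputs and labels must be non-empty and equal in length")
    total = 0.0
    for p, y in zip(outputs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    # training is unused in this sketch; it mirrors the documented signature.
    return total / len(outputs)
```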

compute_logit_length(samples)

Compute the length of the logits (the model's output sequence) for each sample.
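Logit length can differ from input length when a model downsamples time. A minimal sketch, assuming a hypothetical "input_length" key with per-utterance frame counts and an encoder that subsamples by an integer stride with ceiling rounding (as a strided layer with "same" padding does). The actual rule depends on each model's layers.

```python
from typing import Dict, List


def compute_logit_length(samples: Dict[str, List[int]], stride: int = 4) -> List[int]:
    """Number of output frames per utterance after stride-`stride` subsampling."""
    if stride < 1:
        raise ValueError("stride must be a positive integer")
    # Ceiling division: a partial final window still yields one output frame.
    return [(n + stride - 1) // stride for n in samples["input_length"]]
```

For example, input lengths [100, 101, 7] with stride 4 give logit lengths [25, 26, 2].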

reset_metrics()

Reset the model's metrics.
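Metrics typically accumulate state across batches, so they are reset between epochs or evaluation runs. A minimal plain-Python sketch of that pattern; RunningMean and ToyModel are hypothetical, and the method names only mirror the Keras metric convention (update_state, result, reset_state).

```python
from typing import Optional


class RunningMean:
    """Hypothetical accumulating metric: mean of all values seen since reset."""

    def __init__(self, name: str = "loss"):
        self.name = name
        self.reset_state()

    def update_state(self, value: float) -> None:
        self.total += value
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total = 0.0
        self.count = 0


class ToyModel:
    """Hypothetical model holding an optional metric."""

    def __init__(self, metric: Optional[RunningMean] = None):
        self.metric = metric

    def reset_metrics(self) -> None:
        # Clear accumulated state so the next epoch starts from zero.
        if self.metric is not None:
            self.metric.reset_state()
```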

prepare_samples(samples)

Prepare samples for models that need special data handling. Implementations must not change the shape of samples.
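A minimal sketch of a shape-preserving preparation step, assuming samples["input"] (a hypothetical key) is a frames-by-features list of lists. Per-utterance mean normalization changes the values but keeps every dimension intact, and the caller's dict is left unmodified.

```python
from typing import Any, Dict


def prepare_samples(samples: Dict[str, Any]) -> Dict[str, Any]:
    """Subtract the per-feature mean over frames; output shape equals input shape."""
    frames = samples["input"]
    prepared = dict(samples)  # shallow copy so the caller's dict is not modified
    if not frames:
        return prepared
    num_feats = len(frames[0])
    means = [sum(f[d] for f in frames) / len(frames) for d in range(num_feats)]
    prepared["input"] = [[f[d] - means[d] for d in range(num_feats)] for f in frames]
    return prepared
```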

restore_from_pretrained_model(pretrained_model, model_type='')

Restore this model from pretrained_model, optionally specifying model_type (default '').
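A minimal sketch of restoring by name, assuming each model exposes its parameters as a hypothetical name-to-values dict; only entries whose names and sizes match are copied. In TensorFlow, the analogous per-layer step would use Layer.get_weights() and Layer.set_weights(). The sketch accepts model_type only to mirror the documented signature and ignores it, and its return value (the restored names) is an illustrative assumption.

```python
from typing import Dict, List


class ToyModel:
    """Hypothetical model whose parameters live in a name -> flat list dict."""

    def __init__(self, params: Dict[str, List[float]]):
        self.params = params

    def restore_from_pretrained_model(self, pretrained_model: "ToyModel",
                                      model_type: str = "") -> List[str]:
        """Copy matching parameters; return the names that were restored."""
        # model_type is accepted to mirror the documented signature but unused here.
        restored = []
        for name, values in pretrained_model.params.items():
            if name in self.params and len(self.params[name]) == len(values):
                self.params[name] = list(values)  # copy rather than alias
                restored.append(name)
        return restored
```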

decode(samples, hparams, decoder)

Decoding interface: decode samples using the given hparams and decoder.
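A minimal sketch of the decode flow, assuming the model produces per-frame keyword scores and decoder is a callable that turns those scores into detections. ToyKwsModel, threshold_decoder, and the "threshold" hparam are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, List, Optional


def threshold_decoder(scores: List[float], hparams: Dict[str, float]) -> List[int]:
    """Return indices of frames whose score exceeds hparams["threshold"]."""
    threshold = hparams.get("threshold", 0.5)
    return [i for i, s in enumerate(scores) if s > threshold]


class ToyKwsModel:
    """Hypothetical model: the input frames are already keyword scores."""

    def call(self, samples, training: Optional[bool] = None) -> List[float]:
        return list(samples["input"])

    def decode(self, samples, hparams: Dict[str, float],
               decoder: Callable[[List[float], Dict[str, float]], List[int]]) -> List[int]:
        # Run inference (training=False), then delegate post-processing to decoder.
        scores = self.call(samples, training=False)
        return decoder(scores, hparams)
```

Keeping the decoder as a separate argument lets the same model be paired with different search or post-processing strategies without changing the model class.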