athena.models.base

Base class for models.

Module Contents

Classes

BaseModel: Base class for models.
class athena.models.base.BaseModel(**kwargs)

Bases: tensorflow.keras.Model

Base class for models.

call(self, samples, training=None)

Call the model on samples.

get_loss(self, outputs, samples, training=None)

Compute the loss from the model outputs and samples.

compute_logit_length(self, samples)

Compute the logit length.

reset_metrics(self)

Reset the metrics.

prepare_samples(self, samples)

Prepare samples for models that need special data handling. Take care not to change the shape of samples.

restore_from_pretrained_model(self, pretrained_model, model_type='')

Restore the model from a pretrained model.

decode(self, samples, hparams, decoder)

Decoding interface.
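
The sketch below illustrates how a subclass might fill in the interface listed above and how a caller might chain prepare_samples, call, and get_loss. It is not Athena code. BaseModelStandIn is a hypothetical pure-Python stand-in so the snippet runs without TensorFlow or Athena, whereas the real BaseModel derives from tensorflow.keras.Model. The "input" and "output" dictionary keys, the scalar return value of get_loss, the run_step helper, and the call order are all assumptions made for illustration.

```python
class BaseModelStandIn:
    """Hypothetical stand-in that mirrors some documented BaseModel methods.

    The real athena.models.base.BaseModel derives from tensorflow.keras.Model.
    """

    def call(self, samples, training=None):
        raise NotImplementedError

    def get_loss(self, outputs, samples, training=None):
        raise NotImplementedError

    def prepare_samples(self, samples):
        # Default behaviour in this sketch: hand the samples back unchanged.
        return samples


class ScaleModel(BaseModelStandIn):
    """Toy model that multiplies every input value by a fixed weight."""

    def __init__(self, weight=2.0):
        self.weight = weight

    def prepare_samples(self, samples):
        # Convert values to float without changing the shape of any field.
        return {
            "input": [float(x) for x in samples["input"]],
            "output": [float(y) for y in samples["output"]],
        }

    def call(self, samples, training=None):
        return [self.weight * x for x in samples["input"]]

    def get_loss(self, outputs, samples, training=None):
        # Mean squared error between predictions and targets
        # (returning a single scalar is an assumption of this sketch).
        targets = samples["output"]
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def run_step(model, samples, training=True):
    """Assumed call order: prepare the samples, run the model, compute the loss."""
    samples = model.prepare_samples(samples)
    outputs = model.call(samples, training=training)
    loss = model.get_loss(outputs, samples, training=training)
    return outputs, loss


batch = {"input": [1, 2, 3], "output": [2, 4, 7]}
outputs, loss = run_step(ScaleModel(), batch)
```

Passing samples to get_loss as well as to call lets the loss read the targets from the same structure the model consumed, which is one reason prepare_samples should leave the shape of samples untouched.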