athena.models.speech_transformer

speech transformer implementation

Module Contents

Classes

SpeechTransformer     Standard implementation of a SpeechTransformer.
SpeechTransformer2    Decoder for SpeechTransformer2; works with two-pass scheduled sampling.
class athena.models.speech_transformer.SpeechTransformer(data_descriptions, config=None)

Bases: athena.models.base.BaseModel

Standard implementation of a SpeechTransformer. The model mainly consists of three parts: the x_net for input preparation, the y_net for output preparation, and the transformer itself.

default_config
call(self, samples, training: bool = None)

Call the model.

static _create_masks(x, input_length, y)

Generate a square mask for the sequence. The masked positions are filled with float(1.0). Unmasked positions are filled with float(0.0).
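
The mask convention described above (1.0 for masked positions, 0.0 for unmasked ones) can be illustrated with a minimal pure-Python sketch. The helper names `padding_mask` and `look_ahead_mask` are hypothetical and are not part of Athena; the actual `_create_masks` operates on tensors and may differ in shapes and details.

```python
def padding_mask(lengths, max_len):
    """Hypothetical helper: 1.0 marks padded (masked) positions, 0.0 marks real frames."""
    return [[1.0 if t >= n else 0.0 for t in range(max_len)] for n in lengths]


def look_ahead_mask(size):
    """Hypothetical helper: square mask where position i must not attend to any j > i."""
    return [[1.0 if j > i else 0.0 for j in range(size)] for i in range(size)]


# Two utterances of lengths 3 and 1, padded to 4 frames.
pad = padding_mask([3, 1], 4)
# Square mask for a target sequence of length 3.
square = look_ahead_mask(3)
```

Here the padding mask is derived from the input lengths, and the square mask hides future target positions from the decoder.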

compute_logit_length(self, samples)

Used to get the logit length.

time_propagate(self, history_logits, history_predictions, step, enc_outputs)

TODO: docstring

last_predictions: the predictions of the last time step, shape: [beam_size]
history_predictions: the predictions from step 0 to the current time step, shape: [beam_size, time_steps]
states: (step)

decode(self, samples, hparams, decoder, return_encoder=False)

Beam search decoding.

Parameters:
  • samples – the data source to be decoded
  • hparams – decoding configs are included here
  • decoder – it contains the main decoding operations
  • return_encoder – if it is True, encoder_output and input_mask will be returned
Returns:

predictions: the corresponding decoding results, shape: [batch_size, seq_length]; returned only if return_encoder is False
encoder_output: the encoder output computed in decode mode, shape: [batch_size, seq_length, hsize]
input_mask: it is masked by input length, shape: [batch_size, 1, 1, seq_length]
encoder_output and input_mask will be returned only if return_encoder is True

restore_from_pretrained_model(self, pretrained_model, model_type='')

Restore from a pretrained model.

deploy(self)

deployment function

inference_one_step(self, enc_outputs, cur_input, inner_packed_states_array)

Callback function for the WFST decoder.

Parameters:
  • enc_outputs – outputs and mask of encoder
  • cur_input – input sequence for transformer, type: list
  • inner_packed_states_array – inner states that need to be recorded, type: tuple
Returns:

scores: log scores for all labels
inner_packed_states_array: inner states for the next iteration
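
The callback contract described above (take the encoder outputs, the current input sequence, and the packed states; return log scores plus the states for the next iteration) can be sketched with a stand-in function. Everything in this example is an assumption for illustration: `fake_inference_one_step` and `greedy_decode` are hypothetical, the state tuple is a toy step counter, and a real WFST decoder would drive the callback differently.

```python
import math


def fake_inference_one_step(enc_outputs, cur_input, inner_packed_states_array):
    """Stand-in callback: returns log scores over 4 labels and the updated states."""
    step = inner_packed_states_array[0]
    probs = [0.1, 0.1, 0.1, 0.1]
    probs[min(step + 1, 3)] = 0.7  # toy distribution that favours one label per step
    scores = [math.log(p) for p in probs]
    return scores, (step + 1,)


def greedy_decode(step_fn, enc_outputs, sos, eos, max_steps):
    """Hypothetical driver: repeatedly calls step_fn and keeps the best-scoring label."""
    tokens = [sos]
    states = (0,)
    for _ in range(max_steps):
        scores, states = step_fn(enc_outputs, tokens, states)
        best = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(best)
        if best == eos:
            break
    return tokens


# enc_outputs is unused by the stand-in callback, so None is passed here.
result = greedy_decode(fake_inference_one_step, None, sos=0, eos=3, max_steps=5)
```

The point of the sketch is the data flow: the caller owns the loop and threads the returned states back into the next call.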

class athena.models.speech_transformer.SpeechTransformer2(data_descriptions, config=None)

Bases: athena.models.speech_transformer.SpeechTransformer

Decoder for SpeechTransformer2; works with two-pass scheduled sampling.

call(self, samples, training: bool = None)

Call the model.

mix_target_sequence(self, gold_token, predicted_token, training, top_k=5)

Mix gold tokens and predictions.

Parameters:
  • gold_token – true labels
  • predicted_token – predictions by the first pass
Returns:

mix of the gold_token and predicted_token
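
As a rough illustration of mixing gold and first-pass tokens, the sketch below replaces each gold token with the predicted one with some probability. This is only an assumption about the general idea of scheduled sampling: `mix_tokens` and its `mix_prob` parameter are hypothetical, the `top_k` argument of the real method is not modelled, and Athena's actual mixing rule may differ.

```python
import random


def mix_tokens(gold, predicted, mix_prob, rng=None):
    """Hypothetical helper: pick the predicted token with probability mix_prob, else the gold one."""
    rng = rng or random.Random()
    return [p if rng.random() < mix_prob else g for g, p in zip(gold, predicted)]


gold = [5, 8, 2, 9]
predicted = [5, 7, 2, 4]

all_gold = mix_tokens(gold, predicted, mix_prob=0.0)       # never samples predictions
all_pred = mix_tokens(gold, predicted, mix_prob=1.0)       # always samples predictions
mixed = mix_tokens(gold, predicted, mix_prob=0.5, rng=random.Random(0))
```

With `mix_prob=0.0` the second pass sees pure teacher forcing; raising it exposes the decoder to more of its own first-pass predictions.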