athena.models.speech_transformer
Speech transformer implementation.
Module Contents
Classes
SpeechTransformer | Standard implementation of a SpeechTransformer. The model mainly consists of three parts: the x_net for input preparation, the y_net for output preparation, and the transformer itself.
SpeechTransformer2 | Decoder for SpeechTransformer2; works with two-pass scheduled sampling.
-
class athena.models.speech_transformer.SpeechTransformer(data_descriptions, config=None)
Bases: athena.models.base.BaseModel
Standard implementation of a SpeechTransformer. The model mainly consists of three parts: the x_net for input preparation, the y_net for output preparation, and the transformer itself.
-
default_config
-
call(self, samples, training: bool = None)
Call the model.
-
static _create_masks(x, input_length, y)
Generate a square mask for the sequence. Masked positions are filled with float(1.0); unmasked positions are filled with float(0.0).
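The masking convention described above (1.0 for masked positions, 0.0 for unmasked ones) can be illustrated with a minimal NumPy sketch. This is not the Athena implementation, which is not shown here; the helper names `create_padding_mask` and `create_look_ahead_mask` are hypothetical, and the `[batch_size, 1, 1, seq_length]` layout is assumed from the `input_mask` shape documented for `decode`.

```python
import numpy as np


def create_padding_mask(lengths, max_len):
    """Hypothetical helper: 1.0 at padded positions, 0.0 at valid ones.

    Returns an array of shape [batch_size, 1, 1, max_len].
    """
    positions = np.arange(max_len)[None, :]              # [1, max_len]
    lengths = np.asarray(lengths)[:, None]               # [batch_size, 1]
    mask = (positions >= lengths).astype(np.float32)     # [batch_size, max_len]
    return mask[:, None, None, :]


def create_look_ahead_mask(size):
    """Hypothetical helper: square mask with 1.0 at future positions."""
    # Entries strictly above the diagonal are future positions.
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)


padding = create_padding_mask([3, 5], max_len=5)
look_ahead = create_look_ahead_mask(4)
```

Here the first sequence has three valid frames, so its last two positions are masked, and row `i` of `look_ahead` masks every position after `i`.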
-
compute_logit_length(self, samples)
Compute the logit length.
-
time_propagate(self, history_logits, history_predictions, step, enc_outputs)
TODO: docstring
last_predictions: the predictions of the last time step, shape [beam_size]
history_predictions: the history of predictions from step 0 to the current time step, shape [beam_size, time_steps]
states: (step)
-
decode(self, samples, hparams, decoder, return_encoder=False)
Beam search decoding.
Parameters: - samples – the data source to be decoded
- hparams – decoding configs are included here
- decoder – it contains the main decoding operations
- return_encoder – if True, encoder_output and input_mask are returned
Returns: - predictions – the corresponding decoding results, shape [batch_size, seq_length]; returned only if return_encoder is False
- encoder_output – the encoder output computed in decode mode, shape [batch_size, seq_length, hsize]
- input_mask – masked by input length, shape [batch_size, 1, 1, seq_length]
encoder_output and input_mask are returned only if return_encoder is True.
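For readers unfamiliar with beam search, the following is a minimal, generic sketch of the technique on a toy scoring function. It does not reproduce Athena's `decode` or its `decoder` object; `toy_beam_search`, `toy_log_probs`, and the fixed probability table are hypothetical, and no length normalization or batching is applied.

```python
import numpy as np


def toy_beam_search(log_prob_fn, sos_id, eos_id, beam_size, max_steps):
    """Generic beam search over a function mapping a prefix to log-probs."""
    beams = [([sos_id], 0.0)]   # (token prefix, cumulative log-probability)
    finished = []
    for _ in range(max_steps):
        candidates = []
        for tokens, score in beams:
            log_probs = log_prob_fn(tokens)
            # Expand each beam with its beam_size most likely next tokens.
            for token_id in np.argsort(log_probs)[::-1][:beam_size]:
                candidates.append(
                    (tokens + [int(token_id)], score + float(log_probs[token_id]))
                )
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == eos_id:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)      # keep unfinished hypotheses as fallbacks
    return max(finished, key=lambda c: c[1])


# Hypothetical vocabulary: 0 = <sos>, 1 = "a", 2 = "b", 3 = <eos>.
TABLE = np.log(np.array([
    [0.01, 0.60, 0.30, 0.09],
    [0.01, 0.20, 0.30, 0.49],
    [0.01, 0.01, 0.01, 0.97],
]))


def toy_log_probs(tokens):
    """Toy model: the distribution depends only on the current step."""
    step = min(len(tokens) - 1, len(TABLE) - 1)
    return TABLE[step]


best_tokens, best_score = toy_beam_search(
    toy_log_probs, sos_id=0, eos_id=3, beam_size=2, max_steps=3
)
```

In a real model, `log_prob_fn` would be replaced by a decoder step conditioned on the encoder output.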
-
restore_from_pretrained_model(self, pretrained_model, model_type='')
Restore from a pretrained model.
-
deploy(self)
Deployment function.
-
inference_one_step(self, enc_outputs, cur_input, inner_packed_states_array)
Callback function for the WFST decoder.
Parameters: - enc_outputs – outputs and mask of the encoder
- cur_input – input sequence for the transformer, type: list
- inner_packed_states_array – inner states that need to be recorded, type: tuple
Returns: - scores – log scores for all labels
- inner_packed_states_array – inner states for the next iteration
-
-
class athena.models.speech_transformer.SpeechTransformer2(data_descriptions, config=None)
Bases: athena.models.speech_transformer.SpeechTransformer
Decoder for SpeechTransformer2; works with two-pass scheduled sampling.
-
call(self, samples, training: bool = None)
Call the model.
-
mix_target_sequence(self, gold_token, predicted_token, training, top_k=5)
Mix gold tokens and predictions.
Parameters: - gold_token – true labels
- predicted_token – predictions from the first pass
Returns: a mix of gold_token and predicted_token
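The general idea of mixing gold labels with first-pass predictions in scheduled sampling can be sketched as below. This is an illustration under stated assumptions, not Athena's method: the function `mix_tokens`, the `sampling_rate` parameter, and the `rng` argument are hypothetical, both inputs are assumed to be integer token IDs of the same shape, and the documented `top_k` and `training` arguments are not modeled.

```python
import numpy as np


def mix_tokens(gold_token, predicted_token, sampling_rate, rng):
    """Hypothetical helper: pick the prediction with probability sampling_rate.

    Each position independently takes the first-pass prediction with
    probability sampling_rate, otherwise the gold label.
    """
    gold_token = np.asarray(gold_token)
    predicted_token = np.asarray(predicted_token)
    use_prediction = rng.random(gold_token.shape) < sampling_rate
    return np.where(use_prediction, predicted_token, gold_token)


rng = np.random.default_rng(0)
gold = np.array([[5, 6, 7, 8]])   # true labels, [batch_size, seq_length]
pred = np.array([[5, 9, 7, 1]])   # first-pass predictions, same shape
mixed = mix_tokens(gold, pred, sampling_rate=0.5, rng=rng)
```

With `sampling_rate=0.0` the output equals the gold labels (plain teacher forcing), and with `sampling_rate=1.0` it equals the first-pass predictions.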
-