athena.tools.beam_search
The beam search decoder layer in encoder-decoder models.
Module Contents
Classes
BeamSearchDecoder | Beam search decoding used in the seq2seq decoder layer
-
athena.tools.beam_search.CandidateHolder
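This page does not document the fields of CandidateHolder, but the deal_with_uncompleted entry names five of them: cand_seqs, cand_logits, cand_states, cand_scores and cand_parents. The sketch below is a hypothetical stand-in that assumes CandidateHolder is a plain record of those fields. The real class probably holds TensorFlow tensors rather than Python lists, and starting from a single hypothesis is also an assumption.

```python
from typing import List, NamedTuple


class CandidateHolder(NamedTuple):
    """Hypothetical stand-in for athena's CandidateHolder (field names taken from the docs)."""
    cand_seqs: List[List[int]]      # token ids of each live hypothesis
    cand_logits: List[List[float]]  # historical per-step log-probabilities
    cand_states: List[object]       # decoder states carried to the next step
    cand_scores: List[float]        # accumulated log-probability per hypothesis
    cand_parents: List[int]         # index of the hypothesis each one extends


def initial_holder(sos: int) -> CandidateHolder:
    # Assumption: decoding starts from one hypothesis that holds only the start symbol.
    return CandidateHolder(
        cand_seqs=[[sos]],
        cand_logits=[[]],
        cand_states=[None],
        cand_scores=[0.0],
        cand_parents=[0],
    )


holder = initial_holder(sos=1)
print(holder.cand_seqs, holder.cand_scores)  # [[1]] [0.0]
```

Because a NamedTuple is immutable, an update such as "send the new candidates to the next time step" becomes `holder._replace(...)` in this sketch.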
-
class athena.tools.beam_search.BeamSearchDecoder(num_class, sos, eos, beam_size)
Beam search decoding used in the seq2seq decoder layer. This layer is used for evaluation.
-
static build_decoder(hparams, num_class, sos, eos, decoder_one_step, lm_model=None)
Set the time-propagating function of the decoder and initialize the decoder.
Parameters: - hparams – the decoding configuration
- num_class – the vocabulary size
- sos – the index of the start-of-sequence symbol
- eos – the index of the end-of-sequence symbol
- decoder_one_step – the time-propagating function of the decoder
- lm_model – the initialized language model
Returns: the initialized beam search decoder
Return type: beam_search_decoder
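build_decoder receives decoder_one_step, the function that advances the decoder by one time step. This page does not give its exact signature, so the sketch below assumes a hypothetical contract: it takes the candidate sequences, their states and the encoder outputs, and returns per-candidate log-probabilities over num_class tokens along with the updated states. A toy bigram table stands in for a real transformer decoder.

```python
import math
from typing import List, Optional, Tuple

# Hypothetical toy "decoder": a bigram table (row = previous token, column = next token).
# Token 0 plays <sos>, token 3 plays <eos>; tokens 1 and 2 are ordinary labels.
BIGRAM = [
    [0.0, 0.6, 0.3, 0.1],
    [0.0, 0.1, 0.5, 0.4],
    [0.0, 0.3, 0.1, 0.6],
    [0.0, 0.0, 0.0, 1.0],
]


def decoder_one_step(cand_seqs: List[List[int]], cand_states: List[Optional[int]],
                     encoder_outputs: object) -> Tuple[List[List[float]], List[int]]:
    """Advance every candidate by one time step (hypothetical signature).

    A real transformer decoder would attend over encoder_outputs and use cand_states;
    this toy ignores both and uses the sequence length as its new "state".
    """
    log_probs = [[math.log(p) if p > 0 else -math.inf for p in BIGRAM[seq[-1]]]
                 for seq in cand_seqs]
    new_states = [len(seq) for seq in cand_seqs]
    return log_probs, new_states


log_probs, states = decoder_one_step([[0], [0, 1]], [None, None], encoder_outputs=None)
print(states)  # [1, 2]
```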
-
set_lm_model(self, lm_model)
Set the language model.
Parameters: - lm_model – the language model
-
set_ctc_scorer(self, ctc_scorer)
Set the CTC scorer.
Parameters: - ctc_scorer – the CTC scorer
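The language model and CTC scorer are auxiliary scorers. Their per-token log-probabilities have to be combined with the decoder's own scores. One widely used scheme, from joint CTC/attention decoding with shallow language-model fusion, is a weighted sum. Two things in the sketch are assumptions: that athena combines its scorers this way, and the weight names ctc_weight and lm_weight.

```python
from typing import List


def fuse_scores(att: List[float], ctc: List[float], lm: List[float],
                ctc_weight: float = 0.3, lm_weight: float = 0.5) -> List[float]:
    """Combine per-token log-probabilities from the decoder and two auxiliary scorers.

    Hypothetical weighting (joint CTC/attention plus shallow LM fusion):
        (1 - ctc_weight) * att + ctc_weight * ctc + lm_weight * lm
    Setting a weight to 0.0 disables that scorer.
    """
    return [(1.0 - ctc_weight) * a + ctc_weight * c + lm_weight * l
            for a, c, l in zip(att, ctc, lm)]


scores = fuse_scores([-1.0, -2.0], [-0.5, -3.0], [-1.0, -1.0])
print(scores)  # roughly [-1.35, -2.8]
```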
-
beam_search_score(self, candidate_holder, encoder_outputs)
Call the time-propagating function and fetch the acoustic score at the current step. If needed, call the auxiliary scorer and update cand_states in candidate_holder.
Parameters: - candidate_holder – its cand_seqs and cand_logits are needed by the transformer decoder to compute the output. Type: CandidateHolder
- encoder_outputs – the outputs of the transformer encoder. Type: tuple, (encoder_outputs, input_mask)
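After the current-step scores are fetched, a beam search typically adds them to each candidate's accumulated score. This gives one extension score per (candidate, token) pair. The sketch below shows that addition on plain lists. The name new_scores mirrors the parameter of deal_with_completed. Whether athena performs this addition inside beam_search_score or in its caller is an assumption.

```python
from typing import List


def step_scores(cand_scores: List[float], log_probs: List[List[float]]) -> List[List[float]]:
    """Score every one-token extension: new_scores[i][k] = cand_scores[i] + log_probs[i][k]."""
    return [[score + lp for lp in row] for score, row in zip(cand_scores, log_probs)]


new_scores = step_scores([0.0, -1.0], [[-0.5, -1.0], [-0.25, -2.0]])
print(new_scores)  # [[-0.5, -1.0], [-1.25, -3.0]]
```

The result has shape [num_candidates, num_class]. Both deal_with_completed (the eos column) and deal_with_uncompleted (all other columns) read from this matrix.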
-
deal_with_completed(self, completed_scores, completed_seqs, completed_length, new_scores, candidate_holder, max_seq_len)
- Add the newly completed seqs and their scores to completed_seqs
- Select the top beam_size most probable completed seqs and their scores
Parameters: - completed_scores – the scores of completed_seqs
- completed_seqs – the historical top beam_size most probable completed seqs
- completed_length – the lengths of completed_seqs
- new_scores – the scores at the current time step
- candidate_holder – the current candidates. Type: CandidateHolder
- max_seq_len – the maximum acceptable output length
Returns: - new_completed_scores – the new top scores
- completed_seqs – the new top completed seqs
- completed_length – the lengths of the new top completed seqs
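The two steps above can be sketched on plain Python lists. Every candidate that emits eos at this step becomes a completed hypothesis scored by the eos column of new_scores. These hypotheses are merged with the previously completed ones, and the best beam_size are kept. The function merge_completed is hypothetical, and it omits the max_seq_len handling and any length normalization the real method may apply.

```python
from typing import List, Tuple


def merge_completed(completed: List[Tuple[float, List[int]]],
                    new_scores: List[List[float]],
                    cand_seqs: List[List[int]],
                    eos: int, beam_size: int) -> List[Tuple[float, List[int]]]:
    """Add hypotheses that end with eos at this step, then keep the best beam_size."""
    finished = [(row[eos], seq + [eos]) for row, seq in zip(new_scores, cand_seqs)]
    merged = completed + finished
    merged.sort(key=lambda item: item[0], reverse=True)  # highest log-probability first
    return merged[:beam_size]


best = merge_completed(
    completed=[(-2.0, [0, 2, 1, 3])],
    new_scores=[[-5.0, -1.0, -3.0, -1.5], [-4.0, -2.0, -2.0, -3.5]],
    cand_seqs=[[0, 1], [0, 2]],
    eos=3,
    beam_size=2,
)
print(best)  # [(-1.5, [0, 1, 3]), (-2.0, [0, 2, 1, 3])]
```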
-
deal_with_uncompleted(self, new_scores, new_cand_logits, new_states, candidate_holder)
- Select the top probable candidate seqs from the new predictions, along with their scores
- Update candidate_holder based on the top probable candidates
Parameters: - new_scores – the prediction scores at the current time step
- new_cand_logits – the historical prediction scores
- new_states – the updated states
- candidate_holder – the current candidates. Type: CandidateHolder
Returns: candidate_holder, whose cand_seqs, cand_logits, cand_states, cand_scores and cand_parents are updated here and passed to the next time step
Return type: CandidateHolder
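Selecting the top uncompleted candidates amounts to a top-k over all (parent, token) extensions, excluding eos, since those are handled by the completed branch. Each winner records the index of the parent it extends, and the new states are gathered by those parent indices. The function select_uncompleted and its list-based layout are hypothetical. In TensorFlow, the selection would more likely be a top-k over the flattened [beam * num_class] scores.

```python
from typing import List, Tuple


def select_uncompleted(new_scores: List[List[float]], new_states: List[object],
                       cand_seqs: List[List[int]], eos: int, beam_size: int
                       ) -> Tuple[List[List[int]], List[float], List[int], List[object]]:
    """Pick the beam_size best non-eos extensions and gather their parents' states."""
    flat = [(score, parent, token)
            for parent, row in enumerate(new_scores)
            for token, score in enumerate(row)
            if token != eos]  # eos extensions belong to the completed branch
    flat.sort(key=lambda item: item[0], reverse=True)
    top = flat[:beam_size]
    seqs = [cand_seqs[parent] + [token] for _, parent, token in top]
    scores = [score for score, _, _ in top]
    parents = [parent for _, parent, _ in top]
    states = [new_states[parent] for parent in parents]  # reorder states to follow parents
    return seqs, scores, parents, states


seqs, scores, parents, states = select_uncompleted(
    new_scores=[[-5.0, -1.0, -1.5, -0.5], [-4.0, -2.0, -3.0, -3.5]],
    new_states=["s0", "s1"],
    cand_seqs=[[0, 1], [0, 2]],
    eos=3,
    beam_size=2,
)
print(seqs, parents)  # [[0, 1, 1], [0, 1, 2]] [0, 0]
```

Here both survivors extend candidate 0, so candidate 1 drops out of the beam. The parent indices are what allow states and histories to be reordered consistently.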
-
__call__(self, cand_seqs, cand_states, init_states, encoder_outputs)
Parameters: - cand_seqs – TensorArray list, element shape: [beam]
- cand_states – [history_predictions]
- init_states – state list
- encoder_outputs – (encoder_outputs, memory_mask, …)
Returns: the completed sequence with the highest score
Return type: completed_seqs
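The pieces documented on this page fit together in a decoding loop. At each step, score the extensions, move eos extensions to the completed set, keep the best non-eos extensions as the new beam, and finally return the best completed sequence. The sketch below assumes a toy bigram table in place of the transformer decoder and no auxiliary scorers. It also assumes a hypothetical early stop, which is valid here because log-probabilities only decrease: once the best completed score is at least the best live score, no live candidate can overtake it. It is not athena's implementation.

```python
import math
from typing import List, Tuple

# Hypothetical toy bigram "decoder": row = previous token, column = next token.
# Token 0 plays <sos>, token 3 plays <eos>; tokens 1 and 2 are ordinary labels.
BIGRAM = [
    [0.0, 0.6, 0.3, 0.1],
    [0.0, 0.1, 0.5, 0.4],
    [0.0, 0.3, 0.1, 0.6],
    [0.0, 0.0, 0.0, 1.0],
]


def log_probs_for(seq: List[int]) -> List[float]:
    # Log-probabilities of the next token given the last token of seq.
    return [math.log(p) if p > 0 else -math.inf for p in BIGRAM[seq[-1]]]


def beam_search(sos: int, eos: int, beam_size: int, max_seq_len: int) -> List[int]:
    cand_seqs: List[List[int]] = [[sos]]
    cand_scores: List[float] = [0.0]
    completed: List[Tuple[float, List[int]]] = []
    for _ in range(max_seq_len):
        # 1. Score every (candidate, token) extension.
        new_scores = [[s + lp for lp in log_probs_for(seq)]
                      for s, seq in zip(cand_scores, cand_seqs)]
        # 2. Completed branch: hypotheses that emit eos now; keep the best beam_size.
        completed += [(row[eos], seq + [eos]) for row, seq in zip(new_scores, cand_seqs)]
        completed = sorted(completed, key=lambda c: c[0], reverse=True)[:beam_size]
        # 3. Uncompleted branch: the best non-eos extensions become the new beam.
        flat = [(score, p, t) for p, row in enumerate(new_scores)
                for t, score in enumerate(row) if t != eos]
        flat.sort(key=lambda c: c[0], reverse=True)
        top = flat[:beam_size]
        cand_seqs = [cand_seqs[p] + [t] for _, p, t in top]
        cand_scores = [score for score, _, _ in top]
        # 4. Scores never increase, so stop once no live candidate can beat the best completed one.
        if completed and (not cand_scores or completed[0][0] >= cand_scores[0]):
            break
    return completed[0][1] if completed else cand_seqs[0]


print(beam_search(sos=0, eos=3, beam_size=2, max_seq_len=10))  # [0, 1, 3]
```

With these toy probabilities, the returned sequence [0, 1, 3] has probability 0.6 * 0.4 = 0.24. Always taking the locally most likely token would instead produce [0, 1, 2, 3], with probability 0.6 * 0.5 * 0.6 = 0.18.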
-
static