athena.utils.misc

misc

Module Contents

Functions

mask_index_from_labels(labels, index)
insert_sos_in_labels(labels, sos)
remove_eos_in_labels(input_labels, labels_length) – Remove EOS from labels; the batch size should be larger than 1.
insert_eos_in_labels(input_labels, eos, labels_length) – Insert EOS into labels; the batch size should be larger than 1.
generate_square_subsequent_mask(size) – Generate a square mask for the sequence. The masked positions are filled with float(1.0).
validate_seqs(seqs, eos) – Discard the end symbol and all elements after it.
get_wave_file_length(wave_file) – Get the wave file length (duration) in ms.
splice_numpy(x, context) – Splice a tensor along the last dimension with context.
set_default_summary_writer(summary_directory=None)
tensor_shape(tensor) – Return a list with the tensor shape.
athena.utils.misc.mask_index_from_labels(labels, index)
athena.utils.misc.insert_sos_in_labels(labels, sos)
athena.utils.misc.remove_eos_in_labels(input_labels, labels_length)

Remove EOS from labels. The batch size should be larger than 1. Assumes 0 is the padding value and the last valid element of each sequence is the EOS.
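As an illustration of the padding convention described above, here is a pure-Python sketch on nested lists, not the TensorFlow implementation. The helper name remove_eos_sketch is hypothetical, and it assumes that labels_length counts the EOS, so the EOS of row i sits at position labels_length[i] - 1.

```python
from typing import List


def remove_eos_sketch(labels: List[List[int]], lengths: List[int]) -> List[List[int]]:
    """Drop the EOS token from each zero-padded label row.

    Assumption: row i holds lengths[i] valid tokens (EOS included) followed by
    0 padding, so the EOS is at position lengths[i] - 1.
    """
    out = []
    for row, length in zip(labels, lengths):
        trimmed = row[:-1]  # the padded width shrinks by one column
        # keep the first length - 1 tokens and re-pad everything else with 0
        out.append([tok if pos < length - 1 else 0 for pos, tok in enumerate(trimmed)])
    return out


# EOS id 9 is a hypothetical choice for this example
labels = [[5, 6, 9, 0], [7, 8, 4, 9]]
print(remove_eos_sketch(labels, [3, 4]))  # [[5, 6, 0], [7, 8, 4]]
```

Dropping one column works for the whole batch because the longest row ends with its EOS in the last column.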

athena.utils.misc.insert_eos_in_labels(input_labels, eos, labels_length)

Insert EOS into labels. The batch size should be larger than 1. Assumes 0 is the padding value.
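A pure-Python sketch of the same idea on nested lists (not the library code): append one padding column, then write the EOS right after the last valid token. The helper name insert_eos_sketch is hypothetical, and it assumes labels_length counts only the non-padding tokens of each row.

```python
from typing import List


def insert_eos_sketch(labels: List[List[int]], eos: int, lengths: List[int]) -> List[List[int]]:
    """Write `eos` right after the last valid token of each zero-padded row.

    Assumption: lengths[i] is the number of non-padding tokens in row i.
    """
    out = []
    for row, length in zip(labels, lengths):
        padded = row + [0]    # one extra column so the longest row has room for EOS
        padded[length] = eos  # this slot held padding (0) before the write
        out.append(padded)
    return out


print(insert_eos_sketch([[5, 6, 0], [7, 8, 4]], eos=9, lengths=[2, 3]))
# [[5, 6, 9, 0], [7, 8, 4, 9]]
```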

athena.utils.misc.generate_square_subsequent_mask(size)

Generate a square mask for the sequence. The masked positions are filled with float(1.0). Unmasked positions are filled with float(0.0).
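A minimal sketch of such a mask, assuming the usual causal convention that position i may not attend to any later position j > i (so the strictly upper triangle is masked). The helper name square_subsequent_mask is hypothetical and uses plain lists rather than a tf.Tensor.

```python
from typing import List


def square_subsequent_mask(size: int) -> List[List[float]]:
    """mask[i][j] is 1.0 (masked) when j lies after i, else 0.0 (unmasked)."""
    return [[1.0 if j > i else 0.0 for j in range(size)] for i in range(size)]


for row in square_subsequent_mask(3):
    print(row)
# [0.0, 1.0, 1.0]
# [0.0, 0.0, 1.0]
# [0.0, 0.0, 0.0]
```

In TensorFlow the same matrix can be expressed as 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0), since band_part with num_lower=-1 and num_upper=0 keeps the lower triangle including the diagonal.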

athena.utils.misc.validate_seqs(seqs, eos)

Discard the end symbol (eos) and all elements after it.

Parameters: seqs – a tf.Tensor with shape (batch_size, seq_length)
Returns: validated_preds
Return type: tf.SparseTensor
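The truncation rule can be sketched in pure Python, with a COO-style (indices, values, dense_shape) triple standing in for tf.SparseTensor. The helper name validate_seqs_sketch is hypothetical; it also assumes 0 is padding and drops zero entries, the way a dense-to-sparse conversion such as tf.sparse.from_dense would.

```python
from typing import List, Tuple


def validate_seqs_sketch(
    seqs: List[List[int]], eos: int
) -> Tuple[List[Tuple[int, int]], List[int], Tuple[int, int]]:
    """Cut each row at its first EOS and return the surviving nonzero entries
    as (indices, values, dense_shape)."""
    indices, values = [], []
    for b, row in enumerate(seqs):
        for t, tok in enumerate(row):
            if tok == eos:
                break  # discard EOS and every element after it
            if tok != 0:  # zeros are treated as padding and left out
                indices.append((b, t))
                values.append(tok)
    dense_shape = (len(seqs), len(seqs[0]) if seqs else 0)
    return indices, values, dense_shape


print(validate_seqs_sketch([[3, 4, 2, 5], [6, 2, 7, 0]], eos=2))
# ([(0, 0), (0, 1), (1, 0)], [3, 4, 6], (2, 4))
```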
athena.utils.misc.get_wave_file_length(wave_file)

Get the length (duration) of a wave file in ms.

Parameters: wave_file – the path of the wave file
Returns: the length of the wave file in ms
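The duration follows from the frame count and the frame rate. A minimal sketch using Python's standard wave module; the helper name wave_length_ms is hypothetical, and it does not reproduce any error handling the library function may have (for example, for a missing file).

```python
import os
import tempfile
import wave


def wave_length_ms(path: str) -> int:
    """Duration of a WAV file in whole milliseconds: frames / frame_rate * 1000."""
    with wave.open(path, "rb") as wav_file:
        frames = wav_file.getnframes()
        rate = wav_file.getframerate()
    return int(frames / rate * 1000)


# Usage: write 0.5 s of 16-bit mono silence at 16 kHz, then measure it.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "silence.wav")
    with wave.open(path, "wb") as out:
        out.setnchannels(1)
        out.setsampwidth(2)
        out.setframerate(16000)
        out.writeframes(b"\x00\x00" * 8000)  # 8000 frames at 16 kHz
    measured = wave_length_ms(path)

print(measured)  # 500
```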
athena.utils.misc.splice_numpy(x, context)

Splice a tensor along the last dimension with context. For example:

t = [[[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9]]]
splice_numpy(t, [0, 1]) =
    [[[1, 2, 3, 4, 5, 6],
      [4, 5, 6, 7, 8, 9],
      [7, 8, 9, 7, 8, 9]]]

Parameters:
  • x – a tf.Tensor with shape (B, T, D) a.k.a. (N, H, W)
  • context – a list of context offsets
Returns:
  spliced tensor with shape (…, D * len(context))
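The example above can be reproduced with a pure-Python sketch on nested lists: for each frame t, concatenate the frames at t + offset for every offset in context. The helper name splice_sketch is hypothetical, and clamping out-of-range offsets to the nearest edge frame is an assumption; it matches the documented example, where the last frame is paired with itself.

```python
from typing import List


def splice_sketch(x: List[List[List[int]]], context: List[int]) -> List[List[List[int]]]:
    """Return a (B, T, D * len(context)) nested list built by concatenating,
    for each frame t, the frames at t + offset (clamped to [0, T - 1])."""
    spliced = []
    for seq in x:  # seq is a (T, D) list of frames
        T = len(seq)
        rows = []
        for t in range(T):
            row = []
            for offset in context:
                src = min(max(t + offset, 0), T - 1)  # clamp to a valid frame index
                row.extend(seq[src])
            rows.append(row)
        spliced.append(rows)
    return spliced


t = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
print(splice_sketch(t, [0, 1]))
# [[[1, 2, 3, 4, 5, 6], [4, 5, 6, 7, 8, 9], [7, 8, 9, 7, 8, 9]]]
```

With a negative offset such as [-1, 0], the first frame is paired with itself under the same clamping assumption.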

athena.utils.misc.set_default_summary_writer(summary_directory=None)
athena.utils.misc.tensor_shape(tensor)

Return a list with the tensor shape. For each dimension, use the static size from tensor.get_shape() if it is known; otherwise, use the dynamic size from tf.shape().
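The per-dimension fallback can be sketched without TensorFlow. Here the hypothetical helper merge_shapes takes both shapes as plain lists: static_shape stands in for tensor.get_shape().as_list() (with None for unknown sizes) and dynamic_shape stands in for the values of tf.shape(tensor).

```python
from typing import List, Optional, Sequence


def merge_shapes(static_shape: Sequence[Optional[int]], dynamic_shape: Sequence[int]) -> List[int]:
    """Prefer each dimension's static size; fall back to the dynamic size
    wherever the static size is unknown (None)."""
    return [s if s is not None else d for s, d in zip(static_shape, dynamic_shape)]


# e.g. a batch whose size is only known at run time
print(merge_shapes([None, 80, 40], [32, 80, 40]))  # [32, 80, 40]
```

Mixing the two keeps known sizes as Python ints, which is convenient for shape arithmetic, while unknown sizes stay dynamic.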