athena.layers.commons

Utils for common layers.

Module Contents

Classes

PositionalEncoding Positional encoding that can be used in Transformer models
ScaledPositionalEncoding Scaled positional encoding
Collapse4D Collapse4D can be used in CNN-LSTM models for speech processing
Gelu Gaussian Error Linear Unit
TdnnLayer An implementation of a TDNN layer
ZoneOutCell Wrapper for an LSTM cell to create a ZoneOut cell

class athena.layers.commons.PositionalEncoding(d_model, max_position=800, scale=False)

Bases: tensorflow.keras.layers.Layer

Positional encoding that can be used in Transformer models.

call(self, x)

Call function: applies the layer to the input tensor x.
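
The page does not show how the encoding is computed. The sketch below assumes the standard sinusoidal formulation commonly used in Transformer models, with sines and cosines interleaved across channels; the actual layer may lay out the channels differently. The helper names are hypothetical, and treating `scale=True` as a multiplication of the input by sqrt(d_model) is an assumption, not something the signature confirms.

```python
import math


def sinusoidal_positional_encoding(max_position, d_model):
    """Return a max_position x d_model table of sinusoidal encodings.

    Assumption: even channels hold sines, odd channels hold cosines.
    """
    table = []
    for pos in range(max_position):
        row = []
        for i in range(d_model):
            # Channels 2k and 2k+1 share the same frequency.
            angle = pos / (10000 ** (2 * (i // 2) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def add_positional_encoding(x, table, scale=False):
    """Add the encoding to x, a list of T frames of d_model floats each.

    Assumption: scale=True multiplies the input by sqrt(d_model) first.
    """
    d_model = len(x[0])
    factor = math.sqrt(d_model) if scale else 1.0
    return [
        [factor * value + table[t][i] for i, value in enumerate(frame)]
        for t, frame in enumerate(x)
    ]


# Usage: two frames of a 4-dimensional input, default max_position of 800.
table = sinusoidal_positional_encoding(max_position=800, d_model=4)
encoded = add_positional_encoding([[1.0] * 4, [1.0] * 4], table, scale=True)
```

Precomputing the table once up to `max_position` and slicing it per input length is a common design, which would explain why the constructor takes `max_position` as an argument.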

class athena.layers.commons.ScaledPositionalEncoding(d_model, max_position=800)

Bases: athena.layers.commons.PositionalEncoding

Scaled positional encoding. Reference: https://arxiv.org/pdf/1809.08895.pdf

build(self, _)
call(self, x)

Call function: applies the layer to the input tensor x.
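
In the referenced paper, the positional encoding is multiplied by a trainable scalar weight before it is added to the input, so the model can adapt how strongly position information contributes. A minimal sketch of that update follows; the function name, the name `alpha`, and its initial value of 1.0 are assumptions for illustration, and the page does not say how the layer creates or initializes its weight (presumably in `build`).

```python
def add_scaled_positional_encoding(x, table, alpha=1.0):
    """Return x + alpha * table, frame by frame.

    x:     list of T frames, each a list of d_model floats
    table: precomputed positional encodings, at least T rows
    alpha: scalar weight (trainable in a real layer; a plain float here)
    """
    return [
        [value + alpha * table[t][i] for i, value in enumerate(frame)]
        for t, frame in enumerate(x)
    ]


# Usage with a tiny hand-written table (hypothetical values).
table = [[0.0, 1.0], [0.5, 0.5]]
x = [[1.0, 2.0], [3.0, 4.0]]
scaled = add_scaled_positional_encoding(x, table, alpha=0.5)
```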

class athena.layers.commons.Collapse4D

Bases: tensorflow.keras.layers.Layer

Collapse4D can be used in CNN-LSTM models for speech processing. It reshapes the input from [N, T, D, C] to [N, T, D*C].

call(self, x)
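
The reshape described above merges the last two axes so that a convolutional output can be fed to a recurrent layer as one feature vector per time step. The sketch below shows the same operation on nested Python lists; the function name is hypothetical, and a row-major element order (D outer, C inner) is assumed, which is what a standard reshape produces.

```python
def collapse_4d(x):
    """Flatten the last two axes: [N][T][D][C] -> [N][T][D*C].

    Assumption: row-major order, so all C channels of the first
    D index come first, then those of the second D index, and so on.
    """
    return [
        [[value for channels in frame for value in channels] for frame in utterance]
        for utterance in x
    ]


# Usage: N=1 utterance, T=2 frames, D=2 feature bins, C=3 channels.
x = [[[[1, 2, 3], [4, 5, 6]],
      [[7, 8, 9], [10, 11, 12]]]]
flat = collapse_4d(x)
```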
class athena.layers.commons.Gelu

Bases: tensorflow.keras.layers.Layer

Gaussian Error Linear Unit. This is a smoother version of the ReLU. Original paper: https://arxiv.org/abs/1606.08415

Parameters: x: float Tensor on which to perform the activation.

Returns: x with the GELU activation applied.

call(self, x)
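
For reference, GELU is defined as x * Phi(x), where Phi is the standard normal cumulative distribution function, and the original paper also gives a tanh-based approximation. The scalar sketch below shows both forms; which one the layer uses is not stated on this page, and the function names are hypothetical.

```python
import math


def gelu_exact(x):
    """GELU(x) = x * Phi(x), with Phi written in terms of erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU given in the original paper."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Usage: compare both forms on a few inputs.
for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(value, gelu_exact(value), gelu_tanh(value))
```

Unlike ReLU, both forms are smooth around zero and let small negative inputs pass with a small negative output.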
class athena.layers.commons.TdnnLayer(context, output_dim, use_bias=False, **kwargs)

Bases: tensorflow.keras.layers.Layer

An implementation of a TDNN layer.

Parameters:
context: an int giving the left and right context, or a list of context indexes, e.g. -2, 0, 2
output_dim: the dimension of the linear transform

call(self, x, training=None, mask=None)
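
A TDNN layer is commonly computed by splicing together the frames at the given context offsets and applying one linear transform to the spliced vector. The sketch below illustrates that idea on plain lists. Several details are assumptions not confirmed by this page: an int context `c` is expanded to the offsets -c..c, edges are handled by repeating the first or last frame, and the helper names and weight layout are hypothetical. No bias is added, matching the default `use_bias=False`.

```python
def normalize_context(context):
    """Expand an int c to offsets [-c, ..., c]; pass a list through sorted."""
    if isinstance(context, int):
        return list(range(-context, context + 1))
    return sorted(context)


def tdnn_forward(x, context, weights):
    """Apply a TDNN-style transform to one utterance.

    x:       list of T frames, each a list of D floats
    weights: output_dim rows, each of length len(offsets) * D
    """
    offsets = normalize_context(context)
    num_frames = len(x)
    outputs = []
    for t in range(num_frames):
        spliced = []
        for offset in offsets:
            # Assumption: clamp indexes so edges reuse the boundary frame.
            index = min(max(t + offset, 0), num_frames - 1)
            spliced.extend(x[index])
        outputs.append(
            [sum(w * v for w, v in zip(row, spliced)) for row in weights]
        )
    return outputs


# Usage: 3 frames of dimension 1, context of one frame on each side,
# output_dim of 1 with all-ones weights (sums the spliced window).
y = tdnn_forward([[1.0], [2.0], [3.0]], context=1, weights=[[1.0, 1.0, 1.0]])
```

Passing a sparse list such as -2, 0, 2 skips the intermediate frames, which widens the temporal receptive field without increasing the size of the weight matrix.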
class athena.layers.commons.ZoneOutCell(zoneout_rate=0.0, **kwargs)

Bases: tensorflow.keras.layers.LSTMCell

Wrapper around an LSTM cell that creates a ZoneOut cell.

Inspired by https://github.com/teganmaharaj/zoneout/blob/master/zoneout_tensorflow.py, published by one of the authors of the zoneout paper (https://arxiv.org/pdf/1606.01305.pdf).

call(self, inputs, states, training: bool = None)

Runs vanilla LSTM Cell and applies zoneout.

get_config(self)
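
Zoneout, as described in the paper cited for ZoneOutCell, randomly keeps individual state units at their previous value during training instead of updating them; at inference time the expected value of that random choice is used. The sketch below shows this update for a single state vector. The function name is hypothetical, and whether the cell applies the same `zoneout_rate` to both the hidden and the cell state is not stated on this page.

```python
import random


def zoneout(prev_state, new_state, zoneout_rate, training, rng=random):
    """Combine previous and freshly computed state with zoneout.

    training=True:  each unit keeps its previous value with
                    probability zoneout_rate, else takes the new value.
    training=False: deterministic expectation of the above.
    """
    if training:
        return [
            prev if rng.random() < zoneout_rate else new
            for prev, new in zip(prev_state, new_state)
        ]
    return [
        zoneout_rate * prev + (1.0 - zoneout_rate) * new
        for prev, new in zip(prev_state, new_state)
    ]


# Usage: at inference, a rate of 0.25 blends 25% old state with 75% new state.
blended = zoneout([4.0, 0.0], [0.0, 4.0], zoneout_rate=0.25, training=False)
```

With the default `zoneout_rate=0.0` the update reduces to the vanilla LSTM step, since every unit always takes its new value.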
athena.layers.commons.SUPPORTED_RNNS
athena.layers.commons.ACTIVATIONS