speechbrain.nnet.attention module

Library implementing attention modules.

Authors
  • Ju-Chieh Chou 2020

  • Jianyuan Zhong 2020

  • Loren Lugosch 2020

  • Samuele Cornell 2020

Summary

Classes:

ContentBasedAttention

This class implements a content-based attention module for seq2seq learning.

KeyValueAttention

This class implements a single-headed key-value attention module for seq2seq learning.

LocationAwareAttention

This class implements a location-aware attention module for seq2seq learning.

MultiheadAttention

This class is a wrapper around torch.nn.MultiheadAttention.

PositionalwiseFeedForward

This class implements the position-wise feed-forward module from “Attention Is All You Need”.

RelPosEncXL

Relative positional encoding for the RelPosMHAXL.

RelPosMHAXL

This class implements relative multi-head attention similar to that of Transformer-XL (https://arxiv.org/pdf/1901.02860.pdf).

Reference

class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)[source]

Bases: Module

This class implements a content-based attention module for seq2seq learning.

Reference: Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al., https://arxiv.org/pdf/1409.0473.pdf

Parameters:
  • enc_dim (int) – Size of encoder layer.

  • dec_dim (int) – Size of decoder layer.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • scaling (float) – Factor that controls the degree of sharpening (default: 1.0).

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The output context vector and the attention weights.
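
A minimal sketch of how reset() and forward() are typically used together: reset() once per new batch to clear the attention memory, then forward() once per decoding step. The tensors and the three decoding steps below are purely illustrative:

>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> enc_states = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> net.reset()  # reset the attention memory before a new batch
>>> for _ in range(3):  # one call per decoding step
...     dec_state = torch.rand([4, 25])
...     context, attn = net(enc_states, enc_len, dec_state)
>>> context.shape
torch.Size([4, 5])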

class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)[source]

Bases: Module

This class implements a location-aware attention module for seq2seq learning.

Reference: Attention-Based Models for Speech Recognition, Chorowski et al., https://arxiv.org/pdf/1506.07503.pdf

Parameters:
  • enc_dim (int) – Size of encoder.

  • dec_dim (int) – Size of decoder.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • conv_channels (int) – Number of channels for the location feature.

  • kernel_size (int) – Kernel size of the convolutional layer for the location feature.

  • scaling (float) – Factor that controls the degree of sharpening (default: 1.0).

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
precomputed_enc_h: Tensor | None
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The output context vector and the attention weights.

class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)[source]

Bases: Module

This class implements a single-headed key-value attention module for seq2seq learning.

Reference: “Attention Is All You Need” by Vaswani et al., sec. 3.2.1

Parameters:
  • enc_dim (int) – Size of the encoder feature vectors from which keys and values are computed.

  • dec_dim (int) – Size of the decoder feature vectors from which queries are computed.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The output context vector and the attention weights.

class speechbrain.nnet.attention.RelPosEncXL(emb_dim: int, dtype: dtype = torch.float32)[source]

Bases: Module

Relative positional encoding for the RelPosMHAXL.

Parameters:
  • emb_dim (int) – Size of the embedding, which also determines the size of the last dimension of the positional embedding.

  • dtype (torch.dtype, optional) – If unspecified, defaults to torch.float32. Controls the data type of the output embedding (but does not affect the precision of the computations, which remain torch.float32).

make_pe(seq_len: int)[source]

Builds the positional embedding tensor for a given sequence length.

Parameters:

seq_len (int) – The length of the sequence to create the position embedding for.

Returns:

Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]

Return type:

torch.Tensor

forward(x: Tensor)[source]

Builds the positional embedding tensor. Similar to make_pe() but uses the shape information from the provided tensor.

Parameters:

x (torch.Tensor) – Input tensor of shape [batch_size, seq_len, embed_dim].

Returns:

pos_emb – Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]

Return type:

torch.Tensor
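
No usage example is given for this class above; here is a minimal sketch, with shapes following directly from the make_pe()/forward() documentation (the input tensor is purely illustrative):

>>> x = torch.rand([6, 60, 512])
>>> pos_enc = RelPosEncXL(emb_dim=512)
>>> pos_emb = pos_enc(x)  # built from the shape of x: [1, 2*60-1, 512]
>>> pos_emb.shape
torch.Size([1, 119, 512])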

class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)[source]

Bases: Module

This class implements relative multi-head attention similar to that of Transformer-XL (https://arxiv.org/pdf/1901.02860.pdf).

Parameters:
  • embed_dim (int) – Size of the input feature vectors from which queries, keys, and values are computed (the embedding dimension).

  • num_heads (int) – Number of attention heads.

  • dropout (float, optional) – Dropout rate.

  • vbias (bool, optional) – Whether to use a bias when computing the value projection.

  • vdim (int, optional) – Size of the value projection. Defaults to embed_dim (note that each head has size embed_dim // num_heads).

  • mask_pos_future (bool, optional) – Whether to mask future positional encoding values. Must be True for causal applications, e.g., a decoder.

Example

>>> inputs = torch.rand([6, 60, 512])
>>> pos_emb = torch.rand([1, 2*60-1, 512])
>>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs, pos_emb)
>>> outputs.shape
torch.Size([6, 60, 512])
rel_shift(x)[source]

Relative shift implementation.

forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)[source]

Compute attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • pos_embs (torch.Tensor) – bidirectional sinusoidal positional embedding tensor (1, 2*S-1, E) where S is the max length between source and target sequence lengths, and E is the embedding dimension.

  • key_padding_mask (torch.Tensor) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value True will be ignored while the positions with the value False will be unchanged.

  • attn_mask (torch.Tensor) – 2D mask (L, S) where L is the target sequence length, S is the source sequence length. 3D mask (N*num_heads, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with the value True are not allowed to attend while positions with the value False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.

  • return_attn_weights (bool) – Whether to additionally return the attention weights.

Returns:

  • out (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_score (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
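
A minimal sketch combining RelPosEncXL with RelPosMHAXL and a boolean key_padding_mask (the sentence lengths used to build the mask are purely illustrative; True marks padded positions to be ignored, as documented above):

>>> inputs = torch.rand([6, 60, 512])
>>> pos_enc = RelPosEncXL(emb_dim=512)
>>> pos_emb = pos_enc(inputs)  # [1, 2*60-1, 512]
>>> net = RelPosMHAXL(embed_dim=512, num_heads=8)
>>> lengths = torch.tensor([60, 60, 45, 30, 60, 10])  # hypothetical valid lengths per sentence
>>> key_padding_mask = torch.arange(60)[None, :] >= lengths[:, None]  # (B, S), True = padded
>>> out, attn = net(inputs, inputs, inputs, pos_emb, key_padding_mask=key_padding_mask)
>>> out.shape
torch.Size([6, 60, 512])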

class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)[source]

Bases: Module

This class is a wrapper around torch.nn.MultiheadAttention.

Reference: https://pytorch.org/docs/stable/nn.html

Parameters:
  • nhead (int) – Number of parallel attention heads.

  • d_model (int) – The size of the model layers.

  • dropout (float) – Dropout probability applied to attn_output_weights (default: 0.0).

  • bias (bool) – If True, adds bias as a module parameter (default: True).

  • add_bias_kv (bool) – If True, adds bias to the key and value sequences at dim=0.

  • add_zero_attn (bool) – If True, adds a new batch of zeros to the key and value sequences at dim=1.

  • kdim (int) – Total number of features in the keys (default: None).

  • vdim (int) – Total number of features in the values (default: None).

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool = True, pos_embs: Tensor | None = None)[source]

Compute attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • attn_mask (torch.Tensor, optional) – 2D mask (L, S) where L is the target sequence length, S is the source sequence length. 3D mask (N*num_heads, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with the value True are not allowed to attend while positions with the value False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.

  • key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value True will be ignored while the positions with the value False will be unchanged.

  • return_attn_weights (bool, optional) – True to additionally return the attention weights, False otherwise.

  • pos_embs (torch.Tensor, optional) – Positional embeddings added to the attention map of shape (L, S, E) or (L, S, 1).

Returns:

  • attn_output (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_output_weights (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length. This is returned only if return_attn_weights=True (True by default).
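
A minimal sketch using a boolean causal attn_mask together with a key_padding_mask built from hypothetical sentence lengths (True marks positions that may not be attended or are ignored, following the mask semantics documented above):

>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=512)
>>> causal_mask = torch.triu(torch.ones(60, 60, dtype=torch.bool), diagonal=1)  # (L, S), True = blocked
>>> lengths = torch.tensor([60, 50, 60, 40, 60, 60, 20, 60])  # hypothetical valid lengths
>>> key_padding_mask = torch.arange(60)[None, :] >= lengths[:, None]  # (B, S), True = padded
>>> out, attn = net(inputs, inputs, inputs, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
>>> out.shape
torch.Size([8, 60, 512])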

class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]

Bases: Module

This class implements the position-wise feed-forward module from “Attention Is All You Need”.

Parameters:
  • d_ffn (int) – Hidden layer size.

  • input_shape (tuple, optional) – Expected shape of the input. Alternatively use input_size.

  • input_size (int, optional) – Expected size of the input. Alternatively use input_shape.

  • dropout (float, optional) – Dropout rate.

  • activation (torch.nn.Module, optional) – Activation function to be applied (recommended: ReLU or GELU).

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
forward(x)[source]

Applies PositionalwiseFeedForward to the input tensor x.
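
In a Transformer encoder layer this module is typically wrapped with a residual connection and layer normalization. A minimal sketch of that common pattern (the LayerNorm and the residual addition are not part of this class):

>>> x = torch.rand([8, 60, 512])
>>> ffn = PositionalwiseFeedForward(d_ffn=2048, input_size=512, dropout=0.1)
>>> norm = torch.nn.LayerNorm(512)
>>> out = norm(x + ffn(x))  # post-LN style: residual add, then normalize
>>> out.shape
torch.Size([8, 60, 512])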