speechbrain.lobes.models.PIQ module

This file implements the necessary classes and functions to implement Posthoc Interpretations via Quantization.

Authors

  • Cem Subakan 2023

  • Francesco Paissan 2023

Summary

Classes:

Conv2dEncoder_v2

This class implements a convolutional encoder to extract classification embeddings from logspectra.

ResBlockAudio

This class implements a residual block.

VQEmbedding

Implements VQ Dictionary.

VectorQuantization

This class defines the forward method for vector quantization.

VectorQuantizationStraightThrough

This class defines the forward method for vector quantization.

VectorQuantizedPSIFocalNet_Audio

This class reconstructs log-power spectrograms from a FocalNet classifier's representations.

VectorQuantizedPSIViT_Audio

This class reconstructs log-power spectrograms from a ViT classifier's representations.

VectorQuantizedPSI_Audio

This class reconstructs log-power spectrograms from classifier's representations.

Functions:

get_irrelevant_regions

This function returns a binary matrix that indicates the irrelevant regions in the VQ dictionary, given the labels array.

weights_init

Applies Xavier initialization to network weights.

Reference

speechbrain.lobes.models.PIQ.get_irrelevant_regions(labels, K, num_classes, N_shared=5, stage='TRAIN')[source]

This function returns a binary matrix that indicates the irrelevant regions in the VQ dictionary, given the labels array.

Parameters:
  • labels (torch.Tensor) – One-dimensional tensor of shape [B].

  • K (int) – Number of keys in the dictionary.

  • num_classes (int) – Number of possible classes.

  • N_shared (int) – Number of shared keys.

  • stage (str) – “TRAIN” during training; any other value otherwise.

Returns:

irrelevant_regions

Return type:

torch.Tensor

Example

>>> labels = torch.Tensor([1, 0, 2])
>>> irrelevant_regions = get_irrelevant_regions(labels, 20, 3, 5)
>>> print(irrelevant_regions.shape)
torch.Size([3, 20])
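
A typical use of the mask (a sketch, assuming a True/1 entry marks an irrelevant key, and reusing irrelevant_regions from the example above) is to exclude those keys from the nearest-key search:

>>> dists = torch.randn(3, 20)  # distance of each sample to each of the 20 keys
>>> masked = dists.masked_fill(irrelevant_regions.bool(), float("inf"))
>>> keys = masked.argmin(dim=1)  # each sample can only pick a relevant key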
speechbrain.lobes.models.PIQ.weights_init(m)[source]

Applies Xavier initialization to network weights.
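
A minimal usage sketch, following the standard torch.nn.Module.apply pattern (m is expected to be a torch.nn.Module):

>>> model = Conv2dEncoder_v2()
>>> model = model.apply(weights_init)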

class speechbrain.lobes.models.PIQ.VectorQuantization(*args, **kwargs)[source]

Bases: Function

This class defines the forward method for vector quantization. As VQ is not differentiable, it raises a RuntimeError if .grad() is called. Refer to VectorQuantizationStraightThrough for a straight-through estimate of the gradient of the VQ operation.
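
Conceptually, the forward pass reduces to a nearest-neighbour lookup in the codebook. A minimal sketch of that core step (ignoring the label-based partitioning this class adds):

>>> inputs = torch.randn(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> flat = inputs.reshape(-1, 256)  # [B*W*H, C]
>>> indices = torch.cdist(flat, codebook).argmin(dim=1)  # nearest key per vector
>>> print(indices.reshape(3, 14, 25).shape)
torch.Size([3, 14, 25])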

static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Applies VQ to the input vectors, using codebook as the VQ dictionary.

Parameters:
  • ctx (torch context) – The context object for storing info for backwards.

  • inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).

  • codebook (torch.Tensor) – VQ dictionary for quantization. Expected shape is torch.Size([K, C]), with K dictionary elements.

  • labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).

  • num_classes (int) – Number of possible classes.

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • training (bool) – True if stage is TRAIN.

Returns:

Codebook’s indices for quantized representation

Return type:

torch.Tensor

Example

>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(VectorQuantization.apply(inputs, codebook, labels).shape)
torch.Size([3, 14, 25])
static backward(ctx, grad_output)[source]

Raises a RuntimeError when a gradient is requested for the VQ operation, which is not differentiable.

class speechbrain.lobes.models.PIQ.VectorQuantizationStraightThrough(*args, **kwargs)[source]

Bases: Function

This class defines the forward method for vector quantization. As VQ is not differentiable, it approximates the gradient of the VQ as in https://arxiv.org/abs/1711.00937.

static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Applies VQ to the input vectors, using codebook as the VQ dictionary, and estimates gradients with a Straight-Through (identity) approximation of the quantization step.

Parameters:
  • ctx (torch context) – The context object for storing info for backwards.

  • inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).

  • codebook (torch.Tensor) – VQ dictionary for quantization. Expected shape is torch.Size([K, C]), with K dictionary elements.

  • labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).

  • num_classes (int) – Number of possible classes.

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • training (bool) – True if stage is TRAIN.

Returns:

Quantized representation and codebook’s indices for quantized representation

Return type:

tuple

Example

>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = VectorQuantizationStraightThrough.apply(inputs, codebook, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 14, 25, 256]) torch.Size([1050])
static backward(ctx, grad_output, grad_indices, labels=None, num_classes=None, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Estimates gradient assuming vector quantization as identity function. (https://arxiv.org/abs/1711.00937)
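
The estimator itself is the standard trick from the paper: the quantized tensor is rewritten so that, in the backward pass, quantization behaves like the identity. A generic sketch (not this module's exact code):

>>> z_e = torch.randn(3, 14, 25, 256, requires_grad=True)
>>> codebook = torch.randn(1024, 256)
>>> idx = torch.cdist(z_e.reshape(-1, 256), codebook).argmin(dim=1)
>>> z_q = codebook[idx].view_as(z_e)   # hard quantization (non-differentiable)
>>> z_st = z_e + (z_q - z_e).detach()  # gradients flow straight through to z_e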

class speechbrain.lobes.models.PIQ.Conv2dEncoder_v2(dim=256)[source]

Bases: Module

This class implements a convolutional encoder to extract classification embeddings from logspectra.

Parameters:

dim (int) – Number of channels of the extracted embeddings.

Example

>>> inputs = torch.ones(3, 431, 513)
>>> model = Conv2dEncoder_v2()
>>> print(model(inputs).shape)
torch.Size([3, 256, 26, 32])
forward(x)[source]

Computes the forward pass.

Parameters:

x (torch.Tensor) – Log-power spectrogram. Expected shape torch.Size([B, T, F]).

Returns:

Embeddings

Return type:

torch.Tensor
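
The encoder output can be fed directly to the VQEmbedding module documented below; a sketch chaining the two, assuming matching channel dimensions (dim=256):

>>> inputs = torch.ones(3, 431, 513)
>>> model = Conv2dEncoder_v2()
>>> vq = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(vq(model(inputs), labels).shape)
torch.Size([3, 26, 32])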

class speechbrain.lobes.models.PIQ.ResBlockAudio(dim)[source]

Bases: Module

This class implements a residual block.

Parameters:

dim (int) – Input channels of the tensor to process. Matches output channels of the residual block.

Example

>>> res = ResBlockAudio(128)
>>> x = torch.randn(2, 128, 16, 16)
>>> print(res(x).shape)
torch.Size([2, 128, 16, 16])
forward(x)[source]

Forward step.

Parameters:

x (torch.Tensor) – Tensor to process. Expected shape is torch.Size([B, C, H, W]).

Returns:

Residual block output

Return type:

torch.Tensor

class speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio(dim=128, K=512, numclasses=50, activate_class_partitioning=True, shared_keys=0, use_adapter=True, adapter_reduce_dim=True)[source]

Bases: Module

This class reconstructs log-power spectrograms from classifier’s representations.

Parameters:
  • dim (int) – Dimensionality of VQ vectors.

  • K (int) – Number of elements of VQ dictionary.

  • numclasses (int) – Number of possible classes.

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • use_adapter (bool) – True to learn an adapter for classifier’s representations.

  • adapter_reduce_dim (bool) – True if adapter should compress representations.

Example

>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024)
>>> x = torch.randn(2, 256, 16, 16)
>>> labels = torch.Tensor([0, 2])
>>> logspectra, hcat, z_q_x = psi(x, labels)
>>> print(logspectra.shape, hcat.shape, z_q_x.shape)
torch.Size([2, 1, 257, 257]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
forward(hs, labels)[source]

Forward step. Reconstructs the log-power spectrogram based on the provided labels’ keys in the VQ dictionary.

Parameters:
  • hs (torch.Tensor) – Classifier’s representations.

  • labels (torch.Tensor) – Predicted labels for classifier’s representations.

Returns:

Reconstructed log-power spectrogram, reduced classifier’s representations and quantized classifier’s representations.

Return type:

tuple

class speechbrain.lobes.models.PIQ.VectorQuantizedPSIFocalNet_Audio(dim=1024, **kwargs)[source]

Bases: VectorQuantizedPSI_Audio

This class reconstructs log-power spectrograms from a FocalNet classifier’s representations.

Parameters:
  • dim (int) – Dimensionality of VQ vectors.

  • kwargs (dict) – See documentation of VectorQuantizedPSI_Audio.

Example

>>> psi = VectorQuantizedPSIFocalNet_Audio(dim=256, K=1024)
>>> x = torch.randn(2, 256, 16, 16)
>>> labels = torch.Tensor([0, 2])
>>> logspectra, hcat, z_q_x = psi(x, labels)
>>> print(logspectra.shape, hcat.shape, z_q_x.shape)
torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
class speechbrain.lobes.models.PIQ.VectorQuantizedPSIViT_Audio(dim=768, **kwargs)[source]

Bases: VectorQuantizedPSI_Audio

This class reconstructs log-power spectrograms from a ViT classifier’s representations.

Parameters:
  • dim (int) – Dimensionality of VQ vectors.

  • kwargs (dict) – See documentation of VectorQuantizedPSI_Audio.

Example

>>> psi = VectorQuantizedPSIViT_Audio(dim=256, K=1024)
>>> x = torch.randn(2, 256, 16, 16)
>>> labels = torch.Tensor([0, 2])
>>> logspectra, hcat, z_q_x = psi(x, labels)
>>> print(logspectra.shape, hcat.shape, z_q_x.shape)
torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
class speechbrain.lobes.models.PIQ.VQEmbedding(K, D, numclasses=50, activate_class_partitioning=True, shared_keys=0)[source]

Bases: Module

Implements the VQ dictionary. Wraps VectorQuantization and VectorQuantizationStraightThrough. For more details, refer to those classes.

Parameters:
  • K (int) – Number of elements of VQ dictionary.

  • D (int) – Dimensionality of VQ vectors.

  • numclasses (int) – Number of possible classes.

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

forward(z_e_x, labels=None)[source]

Wraps VectorQuantization. Computes VQ-dictionary indices for input quantization. Note that this forward step is not differentiable.

Parameters:
  • z_e_x (torch.Tensor) – Input tensor to be quantized.

  • labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).

Returns:

Codebook’s indices for quantized representation

Return type:

torch.Tensor

Example

>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(codebook(inputs, labels).shape)
torch.Size([3, 14, 25])
straight_through(z_e_x, labels=None)[source]

Implements vector quantization with a straight-through approximation of the gradient.

Parameters:
  • z_e_x (torch.Tensor) – Input tensor to be quantized.

  • labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).

Returns:

Straight-through quantized representation and quantized representation

Return type:

tuple

Example

>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = codebook.straight_through(inputs, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 256, 14, 25]) torch.Size([3, 256, 14, 25])
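
In a VQ-VAE-style training loop (the recipe of https://arxiv.org/abs/1711.00937, sketched here; not code from this module), the two outputs typically feed the reconstruction path and the codebook/commitment losses:

>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_bar = codebook.straight_through(inputs, labels)
>>> loss_vq = torch.nn.functional.mse_loss(quant_bar, inputs.detach())      # moves the codebook
>>> loss_commit = torch.nn.functional.mse_loss(inputs, quant_bar.detach())  # commits the encoder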