speechbrain.utils.autocast module

This module implements utilities and abstractions for use with torch.autocast, i.e. Automatic Mixed Precision (AMP).

Authors
  • Sylvain de Langen 2023

Summary

Functions:

fwd_default_precision

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).

Reference

speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).

The wrapped forward gains an additional force_allow_autocast keyword parameter. When it is set to True, the function ignores cast_inputs and does not disable autocast, as if the decorator had not been applied. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)

Note that, as of PyTorch 2.1.1, this decorator only affects CUDA AMP. Non-CUDA AMP is left unaffected, and no input tensors are cast! This use case may be supported by this function in the future.

When autocast is not active, this decorator does not change any behavior.

Parameters:
  • fwd (Optional[Callable]) –

    The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32).

    Reminder: If you are decorating a function directly, this argument is already specified implicitly.

  • cast_inputs (Optional[torch.dtype]) –

    If not None (the default being torch.float32), then any floating-point inputs to the wrapped function will be cast to the specified type.

    Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the lower-precision autocast data type. Disabling autocast without casting inputs does not convert such tensors back, so lower-precision operations can still happen inside an autocast-disabled region. Setting this argument avoids that, if desired.
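The dtype behavior described in the note above can be observed with plain PyTorch. This sketch uses CPU autocast with bfloat16 only so that it runs without a GPU; it illustrates dtype propagation in general and does not call fwd_default_precision, which only acts on CUDA AMP.

```python
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # torch.mm runs in lower precision under CPU autocast, so h is bfloat16.
    h = torch.mm(a, b)

    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is disabled here, yet h keeps its bfloat16 dtype,
        # so this multiplication still happens in bfloat16.
        still_low = h * 2.0
        # An explicit upcast restores float32 computation; this mirrors the
        # cast that cast_inputs=torch.float32 applies to CUDA inputs.
        full = h.float() * 2.0
```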

Return type:

The wrapped function