speechbrain.utils.metric_stats module

The metric_stats module provides an abstract class for storing statistics produced over the course of an experiment and summarizing them.

Authors:
  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Gaëlle Laperrière 2021

  • Sahar Ghannay 2021

Summary

Classes:

BinaryMetricStats

Tracks binary metrics, such as precision, recall, F1, EER, etc.

ClassificationStats

Computes statistics pertaining to multi-label classification tasks, as well as tasks that can be loosely interpreted as such for the purpose of evaluations.

EmbeddingErrorRateSimilarity

Implements the similarity function from the EmbER metric as defined by https://www.isca-archive.org/interspeech_2022/roux22_interspeech.pdf

ErrorRateStats

A class for tracking error rates (e.g., WER, PER).

MetricStats

A default class for storing and summarizing arbitrary metrics.

MultiMetricStats

A wrapper that evaluates multiple metrics simultaneously

WeightedErrorRateStats

Metric that reweighs the WER from ErrorRateStats with any chosen method.

Functions:

EER

Computes the EER (and its threshold).

minDCF

Computes the minDCF metric normally used to evaluate speaker verification systems.

multiprocess_evaluation

Runs metric evaluation in parallel over multiple jobs.

sequence_evaluation

Runs metric evaluation sequentially over the inputs.

Reference

class speechbrain.utils.metric_stats.MetricStats(metric, n_jobs=1, batch_eval=True)[source]

Bases: object

A default class for storing and summarizing arbitrary metrics.

More complex metrics can be created by sub-classing this class.

Parameters:
  • metric (function) – The function to use to compute the relevant metric. Should take at least two arguments (predictions and targets) and can optionally take the relative lengths of either or both arguments. Not usually used in sub-classes.

  • n_jobs (int) – The number of jobs to use for computing the metric. If this is more than one, every sample is processed individually, otherwise the whole batch is passed at once.

  • batch_eval (bool) – When True it feeds the evaluation metric with the batched input. When False and n_jobs=1, it performs metric evaluation one-by-one in a sequential way. When False and n_jobs>1, the evaluation runs in parallel over the different inputs using joblib.

Example

>>> from speechbrain.nnet.losses import l1_loss
>>> loss_stats = MetricStats(metric=l1_loss)
>>> loss_stats.append(
...      ids=["utterance1", "utterance2"],
...      predictions=torch.tensor([[0.1, 0.2], [0.2, 0.3]]),
...      targets=torch.tensor([[0.1, 0.2], [0.1, 0.2]]),
...      reduction="batch",
... )
>>> stats = loss_stats.summarize()
>>> stats['average']
0.050...
>>> stats['max_score']
0.100...
>>> stats['max_id']
'utterance2'
clear()[source]

Creates empty container for storage, removing existing stats.

append(ids, *args, **kwargs)[source]

Store a particular set of metric scores.

Parameters:
  • ids (list) – List of ids corresponding to utterances.

  • *args (tuple) – Arguments to pass to the metric function.

  • **kwargs (dict) – Arguments to pass to the metric function.

summarize(field=None)[source]

Summarize the metric scores, returning relevant stats.

Parameters:

field (str) – If provided, only returns selected statistic. If not, returns all computed statistics.

Returns:

Returns a float if field is provided, otherwise returns a dictionary containing all computed stats.

Return type:

float or dict
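
For instance, continuing the class example above, a single statistic can be requested by name:

>>> average = loss_stats.summarize(field="average")  # float, same value as stats['average'] above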

write_stats(filestream, verbose=False)[source]

Write all relevant statistics to file.

Parameters:
  • filestream (file-like object) – A stream for the stats to be written to.

  • verbose (bool) – Whether to also print the stats to stdout.

speechbrain.utils.metric_stats.multiprocess_evaluation(metric, predict, target, lengths=None, n_jobs=8)[source]

Runs metric evaluation in parallel over multiple jobs.

speechbrain.utils.metric_stats.sequence_evaluation(metric, predict, target, lengths=None)[source]

Runs metric evaluation sequentially over the inputs.
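
These helpers are normally invoked internally by MetricStats.append rather than called directly; which one runs depends on the n_jobs and batch_eval settings described above. A brief sketch, with a purely illustrative per-sample metric:

>>> def abs_diff(pred, target):
...     # hypothetical per-sample metric, for illustration only
...     return (pred - target).abs().mean()
>>> sequential_stats = MetricStats(metric=abs_diff, batch_eval=False)          # one-by-one, sequential
>>> parallel_stats = MetricStats(metric=abs_diff, batch_eval=False, n_jobs=4)  # parallel over inputs via joblib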

class speechbrain.utils.metric_stats.ErrorRateStats(merge_tokens=False, split_tokens=False, space_token='_', keep_values=True, extract_concepts_values=False, tag_in='', tag_out='', equality_comparator: Callable[[str, str], bool] = <function _str_equals>)[source]

Bases: MetricStats

A class for tracking error rates (e.g., WER, PER).

Parameters:
  • merge_tokens (bool) – Whether to merge the successive tokens (used for e.g., creating words out of character tokens). See speechbrain.dataio.dataio.merge_char.

  • split_tokens (bool) – Whether to split tokens (used for e.g. creating characters out of word tokens). See speechbrain.dataio.dataio.split_word.

  • space_token (str) – The character to use for boundaries. With merge_tokens, this is the character to split on after merging. With split_tokens, the sequence is joined with this token in between and then the whole sequence is split.

  • keep_values (bool) – Whether to keep the values of the concepts or not.

  • extract_concepts_values (bool) – Whether to process the predictions and targets to keep only concepts and values.

  • tag_in (str) – Start of the concept (‘<’ for example).

  • tag_out (str) – End of the concept (‘>’ for example).

  • equality_comparator (Callable[[str, str], bool]) – The function used to check whether two words are equal.

Example

>>> cer_stats = ErrorRateStats()
>>> i2l = {0: 'a', 1: 'b'}
>>> cer_stats.append(
...     ids=['utterance1'],
...     predict=torch.tensor([[0, 1, 1]]),
...     target=torch.tensor([[0, 1, 0]]),
...     target_len=torch.ones(1),
...     ind2lab=lambda batch: [[i2l[int(x)] for x in seq] for seq in batch],
... )
>>> stats = cer_stats.summarize()
>>> stats['WER']
33.33...
>>> stats['insertions']
0
>>> stats['deletions']
0
>>> stats['substitutions']
1
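
The tokenization flags above can change the unit being scored; for example, a sketch (not from the library’s doctests) of a word-level scorer fed with character sequences that use '_' as the word boundary:

>>> word_stats = ErrorRateStats(merge_tokens=True, space_token="_")
>>> # appended character sequences are merged into words at '_' boundaries
>>> # before edits are counted, yielding a word-level error rate
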
append(ids, predict, target, predict_len=None, target_len=None, ind2lab=None)[source]

Add stats to the relevant containers.

  • See MetricStats.append()

Parameters:
  • ids (list) – List of ids corresponding to utterances.

  • predict (torch.tensor) – A predicted output, for comparison with the target output

  • target (torch.tensor) – The correct reference output, for comparison with the prediction.

  • predict_len (torch.tensor) – The predictions’ relative lengths, used to undo padding if there is padding present in the predictions.

  • target_len (torch.tensor) – The target outputs’ relative lengths, used to undo padding if there is padding present in the target.

  • ind2lab (callable) – Callable that maps from indices to labels, operating on batches, for writing alignments.

summarize(field=None)[source]

Summarize the error_rate and return relevant statistics.

  • See MetricStats.summarize()

write_stats(filestream)[source]

Write all relevant info (e.g., error rate alignments) to file.

  • See MetricStats.write_stats()
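
For example, the error-rate details and alignments collected in the cer_stats example above can be written to disk (the file name is illustrative):

>>> with open("wer_report.txt", "w") as report_file:
...     cer_stats.write_stats(report_file)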

class speechbrain.utils.metric_stats.WeightedErrorRateStats(base_stats: ErrorRateStats, cost_function: Callable[[str, str | None, str | None], float], weight_name: str = 'weighted')[source]

Bases: MetricStats

Metric that reweighs the WER from ErrorRateStats with any chosen method. It does not modify the sequence of edits found (insertions/deletions/substitutions), but multiplies each edit’s impact on the metric by a value between 0 and 1, as returned by the cost function.

Parameters:
  • base_stats (ErrorRateStats) – The base WER calculator to use.

  • cost_function (Callable[[str, Optional[str], Optional[str]], float]) – Cost function of signature fn(edit_symbol, a, b) -> float, where the returned value, between 0 and 1, is the weight that should be assigned to a particular edit in the weighted WER calculation. In the case of insertions and deletions, either of a or b may be None. In the case of substitutions, a and b will never be None.

  • weight_name (str) – Prefix to be prepended to each metric name (e.g. xxx_wer)
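
Example

A minimal sketch (not part of the library’s doctests) that counts substitutions at half weight while keeping insertions and deletions at full weight. It reuses the cer_stats object built in the ErrorRateStats example above and assumes the edit symbols of speechbrain.utils.edit_distance.EDIT_SYMBOLS:

>>> def half_substitutions(edit_symbol, a, b):
...     # hypothetical cost function, for illustration only
...     if edit_symbol == "S":         # substitution
...         return 0.5
...     if edit_symbol in ("I", "D"):  # insertion / deletion
...         return 1.0
...     return 0.0                     # no error
>>> weighted_stats = WeightedErrorRateStats(
...     base_stats=cer_stats,
...     cost_function=half_substitutions,
...     weight_name="half_sub",
... )
>>> summary = weighted_stats.summarize()  # fields are prefixed, e.g. summary["half_sub_wer"]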

append(*args, **kwargs)[source]

Append function, which should NOT be used for the weighted error rate stats. Please append to the specified base_stats instead.

WeightedErrorRateStats reuses the scores from the base ErrorRateStats class.

Parameters:
  • *args – Ignored.

  • **kwargs – Ignored.

summarize(field=None)[source]

Returns a dict containing some detailed WER statistics after weighting every edit with a weight determined by cost_function (0.0 to discount the edit entirely, 1.0 for the default error behavior, and anything in between).

Does not require the base ErrorRateStats’ summarize() to have been called.

Full set of fields, each of which is prepended with `<weight_name_specified_at_init>_`:
  • wer: Weighted WER (ratio * 100)

  • insertions: Weighted insertions

  • substitutions: Weighted substitutions

  • deletions: Weighted deletions

  • num_edits: Sum of weighted insertions/substitutions/deletions

Additionally, a scores list is populated by this function for each pair of sentences. Each entry of that list is a dict with the fields:
  • key: the ID of the utterance.

  • WER, insertions, substitutions, deletions, num_edits: same semantics as described above, but at the sentence level rather than global.

Parameters:

field (str, optional) – The field to return, if you are only interested in one of them. If specified, a single float is returned; otherwise, a dict is returned.

Returns:

  • dict from str to float, if field is None – A dictionary of the fields documented above.

  • float, if field is not None – The single field selected by field.

write_stats(filestream)[source]

Write all relevant info to file; here, only the weighted info as returned by summarize. See write_stats().

class speechbrain.utils.metric_stats.EmbeddingErrorRateSimilarity(embedding_function: Callable[[str], Tensor | None], low_similarity_weight: float, high_similarity_weight: float, threshold: float)[source]

Bases: object

Implements the similarity function from the EmbER metric as defined by https://www.isca-archive.org/interspeech_2022/roux22_interspeech.pdf

This metric involves a dictionary to map a token to a single word embedding. Substitutions in the WER get weighted down when the embeddings are similar enough. The goal is to reduce the impact of substitution errors with small semantic impact. Only substitution errors get weighted.

This is done by computing the cosine similarity between the two embeddings, then weighting the substitution with low_similarity_weight if the similarity is below threshold, or with high_similarity_weight otherwise (e.g., a substitution between highly similar words could be weighted down to matter only 10% as much as one between dissimilar words).

Note

The cited paper recommended (1.0, 0.1, 0.4), i.e. (low_similarity_weight, high_similarity_weight, threshold), as defaults for fastText French embeddings, chosen empirically. When using different embeddings, you might want to test other values; thus we don’t provide defaults.

Parameters:
  • embedding_function (Callable[[str], Optional[torch.Tensor]]) – Function that returns an embedding (as a torch.Tensor) from a word. If no corresponding embedding could be found for the word, should return None. In that case, low_similarity_weight will be chosen.

  • low_similarity_weight (float) – Weight applied to the substitution if cosine_similarity < threshold.

  • high_similarity_weight (float) – Weight applied to the substitution if cosine_similarity >= threshold.

  • threshold (float) – Cosine similarity threshold used to select by how much a substitution error should be weighed for this word.
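
Example

A minimal sketch with a toy dictionary of 2-D vectors standing in for real fastText embeddings; the words, vectors, and the wer_stats object mentioned in the comment are purely illustrative. With these settings, a substitution between words whose cosine similarity reaches 0.4 is weighted at 0.1, and every other substitution at 1.0:

>>> toy_vectors = {
...     "hello": torch.tensor([1.0, 0.0]),
...     "hullo": torch.tensor([0.9, 0.1]),
...     "potato": torch.tensor([0.0, 1.0]),
... }
>>> ember_similarity = EmbeddingErrorRateSimilarity(
...     embedding_function=toy_vectors.get,  # dict.get returns None for unknown words
...     low_similarity_weight=1.0,
...     high_similarity_weight=0.1,
...     threshold=0.4,
... )
>>> # usable directly as the cost function of a WeightedErrorRateStats:
>>> # ember_wer = WeightedErrorRateStats(base_stats=wer_stats, cost_function=ember_similarity)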

__call__(edit_symbol: str, a: str | None, b: str | None) → float[source]

Returns the weight that should be associated with a specific edit in the WER calculation.

Compatible candidate for the cost function of WeightedErrorRateStats so an instance of this class can be passed as a cost_function.

Parameters:
  • edit_symbol (str) – Edit symbol as assigned by the WER functions, see EDIT_SYMBOLS.

  • a (str, optional) – First word to compare (if present)

  • b (str, optional) – Second word to compare (if present)

Returns:

Weight to assign to the edit. For substitutions, either low_similarity_weight or high_similarity_weight, depending on the embedding similarity and the threshold.

Return type:

float

class speechbrain.utils.metric_stats.BinaryMetricStats(positive_label=1)[source]

Bases: MetricStats

Tracks binary metrics, such as precision, recall, F1, EER, etc.
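
Example

A minimal usage sketch; the ids, scores, and labels are illustrative, and scores/labels are shown as tensors here. Metrics are only computed when summarize() is called:

>>> binary_stats = BinaryMetricStats()
>>> binary_stats.append(
...     ids=["utt1", "utt2", "utt3", "utt4"],
...     scores=torch.tensor([0.9, 0.8, 0.3, 0.1]),
...     labels=torch.tensor([1, 1, 0, 0]),
... )
>>> summary = binary_stats.summarize(threshold=0.5)  # dict with TP, TN, FP, FN, precision, recall, ...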

clear()[source]

Clears the stored metrics.

append(ids, scores, labels)[source]

Appends scores and labels to internal lists.

Does not compute metrics until time of summary, since automatic thresholds (e.g., EER) need the full set of scores.

Parameters:
  • ids (list) – The string ids for the samples.

  • scores (list) – The scores corresponding to the ids.

  • labels (list) – The labels corresponding to the ids.

summarize(field=None, threshold=None, max_samples=None, beta=1, eps=1e-08)[source]

Compute statistics using a full set of scores.

Full set of fields:
  • TP - True Positive

  • TN - True Negative

  • FP - False Positive

  • FN - False Negative

  • FAR - False Acceptance Rate

  • FRR - False Rejection Rate

  • DER - Detection Error Rate (EER if no threshold passed)

  • threshold - threshold (EER threshold if no threshold passed)

  • precision - Precision (positive predictive value)

  • recall - Recall (sensitivity)

  • F-score - Balance of precision and recall (equal if beta=1)

  • MCC - Matthews Correlation Coefficient

Parameters:
  • field (str) – A key for selecting a single statistic. If not provided, a dict with all statistics is returned.

  • threshold (float) – If no threshold is provided, equal error rate is used.

  • max_samples (float) – How many samples to keep for positive/negative scores. If no max_samples is provided, all scores are kept. Only effective when threshold is None.

  • beta (float) – How much to weight precision vs recall in F-score. The default of 1.0 gives equal weight, while higher values weight recall higher and lower values weight precision higher.

  • eps (float) – A small value to avoid dividing by zero.

Returns:

If field is specified, only returns the score for that field. If field is None, returns the full set of fields.

Return type:

float or dict

speechbrain.utils.metric_stats.EER(positive_scores, negative_scores)[source]

Computes the EER (and its threshold).

Parameters:
  • positive_scores (torch.tensor) – The scores from entries of the same class.

  • negative_scores (torch.tensor) – The scores from entries of different classes.

Returns:

  • EER (float) – The EER score.

  • threshold (float) – The corresponding threshold for the EER score.

Example

>>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
>>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
>>> val_eer, threshold = EER(positive_scores, negative_scores)
>>> val_eer
0.0
speechbrain.utils.metric_stats.minDCF(positive_scores, negative_scores, c_miss=1.0, c_fa=1.0, p_target=0.01)[source]

Computes the minDCF metric normally used to evaluate speaker verification systems. The min_DCF is the minimum of the following C_det function computed within the defined threshold range:

C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)

where p_miss is the miss probability and p_fa is the false-alarm probability.

Parameters:
  • positive_scores (torch.tensor) – The scores from entries of the same class.

  • negative_scores (torch.tensor) – The scores from entries of different classes.

  • c_miss (float) – Cost assigned to a missing error (default 1.0).

  • c_fa (float) – Cost assigned to a false alarm (default 1.0).

  • p_target (float) – Prior probability of having a target (default 0.01).

Returns:

  • minDCF (float) – The minDCF score.

  • threshold (float) – The corresponding threshold for the minDCF score.

Example

>>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
>>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
>>> val_minDCF, threshold = minDCF(positive_scores, negative_scores)
>>> val_minDCF
0.0
class speechbrain.utils.metric_stats.ClassificationStats[source]

Bases: MetricStats

Computes statistics pertaining to multi-label classification tasks, as well as tasks that can be loosely interpreted as such for the purpose of evaluations.

Example

>>> import sys
>>> from speechbrain.utils.metric_stats import ClassificationStats
>>> cs = ClassificationStats()
>>> cs.append(
...     ids=["ITEM1", "ITEM2", "ITEM3", "ITEM4"],
...     predictions=[
...         "M EY K AH",
...         "T EY K",
...         "B AE D",
...         "M EY K",
...     ],
...     targets=[
...         "M EY K",
...         "T EY K",
...         "B AE D",
...         "M EY K",
...     ],
...     categories=[
...         "make",
...         "take",
...         "bad",
...         "make"
...     ]
... )
>>> cs.write_stats(sys.stdout)
Overall Accuracy: 75%

Class-Wise Accuracy
-------------------
bad -> B AE D : 1 / 1 (100.00%)
make -> M EY K: 1 / 2 (50.00%)
take -> T EY K: 1 / 1 (100.00%)

Confusion
---------
Target: bad -> B AE D
  -> B AE D   : 1 / 1 (100.00%)
Target: make -> M EY K
  -> M EY K   : 1 / 2 (50.00%)
  -> M EY K AH: 1 / 2 (50.00%)
Target: take -> T EY K
  -> T EY K   : 1 / 1 (100.00%)
>>> summary = cs.summarize()
>>> summary['accuracy']
0.75
>>> summary['classwise_stats'][('bad', 'B AE D')]
{'total': 1.0, 'correct': 1.0, 'accuracy': 1.0}
>>> summary['classwise_stats'][('make', 'M EY K')]
{'total': 2.0, 'correct': 1.0, 'accuracy': 0.5}
>>> summary['keys']
[('bad', 'B AE D'), ('make', 'M EY K'), ('take', 'T EY K')]
>>> summary['predictions']
['B AE D', 'M EY K', 'M EY K AH', 'T EY K']
>>> summary['classwise_total']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 2.0, ('take', 'T EY K'): 1.0}
>>> summary['classwise_correct']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 1.0, ('take', 'T EY K'): 1.0}
>>> summary['classwise_accuracy']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 0.5, ('take', 'T EY K'): 1.0}
append(ids, predictions, targets, categories=None)[source]

Appends inputs, predictions and targets to internal lists

Parameters:
  • ids (list) – the string IDs for the samples

  • predictions (list) – the model’s predictions (human-interpretable, preferably strings)

  • targets (list) – the ground truths (human-interpretable, preferably strings)

  • categories (list) – an additional way to classify training samples. If available, the categories will be combined with targets

summarize(field=None)[source]

Summarize the classification metric scores

The following statistics are computed:

  • accuracy: the overall accuracy (# correct / # total)

  • confusion_matrix: a dictionary of type {(target, prediction): num_entries} representing the confusion matrix

  • classwise_stats: the total number of samples, the number of correct classifications, and the accuracy for each class

  • keys: all available class keys, which can be either target classes or (category, target) tuples

  • predictions: all predictions the model has made

Parameters:

field (str) – If provided, only returns selected statistic. If not, returns all computed statistics.

Returns:

Returns a float if field is provided, otherwise returns a dictionary containing all computed stats.

Return type:

float or dict

clear()[source]

Clears the collected statistics

write_stats(filestream)[source]

Outputs the stats to the specified filestream in a human-readable format

Parameters:

filestream (file) – a file-like object

class speechbrain.utils.metric_stats.MultiMetricStats(metric, n_jobs=1, batch_eval=False)[source]

Bases: object

A wrapper that evaluates multiple metrics simultaneously

Parameters:
  • metric (function) – The function to use to compute the relevant metrics. Should take at least two arguments (predictions and targets) and can optionally take the relative lengths of either or both arguments. The function should return a dict or a namedtuple

  • n_jobs (int) – The number of jobs to use for computing the metric. If this is more than one, every sample is processed individually, otherwise the whole batch is passed at once.

  • batch_eval (bool) – When True it feeds the evaluation metric with the batched input. When False and n_jobs=1, it performs metric evaluation one-by-one in a sequential way. When False and n_jobs>1, the evaluation runs in parallel over the different inputs using joblib.

Example

>>> def metric(a, b):
...    return {
...        "sum": a + b,
...        "diff": a - b,
...        "sum_sq": a**2 + b**2
...    }
>>> multi_metric = MultiMetricStats(metric, batch_eval=True)
>>> multi_metric.append([1, 2], a=torch.tensor([2.0, 1.0]), b=torch.tensor([1.0, 2.0]))
>>> multi_metric.append([3, 4], a=torch.tensor([4.0, 5.0]), b=torch.tensor([0.0, 1.0]))
>>> multi_metric.append([5, 6], a=torch.tensor([2.0, 4.0]), b=torch.tensor([4.0, 2.0]))
>>> multi_metric.append([7, 8], a=torch.tensor([2.0, 4.0]), b=torch.tensor([4.0, 2.0]))
>>> multi_metric.summarize() 
{'sum': {'average': 5.0,
  'min_score': 3.0,
  'min_id': 1,
  'max_score': 6.0,
  'max_id': 4},
 'diff': {'average': 1.0,
  'min_score': -2.0,
  'min_id': 5,
  'max_score': 4.0,
  'max_id': 3},
 'sum_sq': {'average': 16.5,
  'min_score': 5.0,
  'min_id': 1,
  'max_score': 26.0,
  'max_id': 4}}
>>> multi_metric.summarize(flat=True) 
{'sum_average': 5.0,
 'sum_min_score': 3.0,
 'sum_min_id': 1,
 'sum_max_score': 6.0,
 'sum_max_id': 4,
 'diff_average': 1.0,
 'diff_min_score': -2.0,
 'diff_min_id': 5,
 'diff_max_score': 4.0,
 'diff_max_id': 3,
 'sum_sq_average': 16.5,
 'sum_sq_min_score': 5.0,
 'sum_sq_min_id': 1,
 'sum_sq_max_score': 26.0,
 'sum_sq_max_id': 4}
append(ids, *args, **kwargs)[source]

Store a particular set of metric scores.

Parameters:
  • ids (list) – List of ids corresponding to utterances.

  • *args (tuple) – Arguments to pass to the metric function.

  • **kwargs (dict) – Arguments to pass to the metric function.

eval_simple(*args, **kwargs)[source]

Evaluates the metric in a simple, sequential manner

summarize(field=None, flat=False)[source]

Summarize the metric scores, returning relevant stats.

Parameters:
  • field (str) – If provided, only returns selected statistic. If not, returns all computed statistics.

  • flat (bool) – whether to flatten the dictionary

Returns:

Returns a dictionary of all computed stats

Return type:

dict