Module laplace.utils.metrics

Classes

class RunningNLLMetric (ignore_index=-100)

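Running negative log-likelihood (NLL) metric that accumulates the NLL of integer targets under predicted class probabilities across calls to update().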

Parameters

ignore_index : int, default = -100
which class label to ignore when computing the NLL loss

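As a reading aid, here is the quantity a running NLL metric of this kind would report. This assumes the reduction is a mean over the labels that are not equal to ignore_index, which the description above does not state explicitly. Here p_{i,c} is the predicted probability of class c for item i, y_i is its integer target, and i runs over every item from every batch passed to update():

```latex
\mathrm{NLL} = -\frac{1}{|\mathcal{V}|} \sum_{i \in \mathcal{V}} \log p_{i, y_i},
\qquad
\mathcal{V} = \{\, i : y_i \neq \texttt{ignore\_index} \,\}
```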

Ancestors

  • torchmetrics.metric.Metric
  • torch.nn.modules.module.Module
  • abc.ABC

Methods

def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None

Add a batch of predicted probabilities and targets to the running NLL.

Parameters

probs : torch.Tensor
probability tensor of shape (…, n_classes)
targets : torch.Tensor
integer tensor of shape (…)

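The shape conventions of update() can be illustrated with a minimal, dependency-free sketch. It assumes the metric keeps a running sum of per-label NLL and a count of non-ignored labels. The class RunningNLLSketch, its fields nll_sum and n_valid, and the helper _flatten are hypothetical names for illustration, not the library's implementation. Nested Python lists stand in for tensors: probs has shape (…, n_classes) and targets has shape (…).

```python
import math


def _flatten(probs, targets):
    """Pair each row of class probabilities with its integer label (hypothetical helper)."""
    # Base case: a single integer label and its row of class probabilities.
    if isinstance(targets, int):
        return [(probs, targets)]
    pairs = []
    for p, t in zip(probs, targets):
        pairs.extend(_flatten(p, t))
    return pairs


class RunningNLLSketch:
    """Hypothetical plain-Python stand-in for a running NLL metric."""

    def __init__(self, ignore_index: int = -100) -> None:
        self.ignore_index = ignore_index
        self.nll_sum = 0.0  # running sum of -log p[target] over non-ignored labels
        self.n_valid = 0    # number of non-ignored labels seen so far

    def update(self, probs, targets) -> None:
        for row, label in _flatten(probs, targets):
            if label == self.ignore_index:
                continue  # ignored labels add neither loss nor count
            self.nll_sum += -math.log(row[label])
            self.n_valid += 1

    def compute(self) -> float:
        # Mean NLL over non-ignored labels; NaN if none were seen (a choice made for this sketch).
        if self.n_valid == 0:
            return float("nan")
        return self.nll_sum / self.n_valid


metric = RunningNLLSketch()
# Batch 1: probs of shape (2, 3), targets of shape (2,); the second label is ignored.
metric.update([[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]], [0, -100])
# Batch 2: probs of shape (1, 2, 3), targets of shape (1, 2).
metric.update([[[0.25, 0.5, 0.25], [0.2, 0.2, 0.6]]], [[1, 2]])
print(metric.compute())  # ≈ 0.6324, i.e. (-ln 0.5 - ln 0.5 - ln 0.6) / 3
```

Flattening every leading dimension before accumulating is what lets one metric object accept batches of different shapes, as long as the last dimension of probs is the class dimension.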
def compute(self)

Return the NLL accumulated over all calls to update().

This method automatically synchronizes the state variables when running with a distributed backend.
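Keeping a running sum and a count, rather than a running mean, makes this kind of synchronization simple: per-process states can be summed and the ratio taken once at the end. The sketch below assumes each process holds a (nll_sum, n_valid) pair and that states are sum-reduced across processes. merge_states and mean_nll are hypothetical helpers, and the plain summation stands in for whatever reduction the distributed backend actually performs.

```python
import math
from typing import Iterable, Tuple

State = Tuple[float, int]  # (sum of per-label NLL, number of non-ignored labels)


def merge_states(states: Iterable[State]) -> State:
    """Sum-reduce per-process states, as a sum all-reduce across processes would."""
    total_nll, total_n = 0.0, 0
    for nll_sum, n_valid in states:
        total_nll += nll_sum
        total_n += n_valid
    return total_nll, total_n


def mean_nll(state: State) -> float:
    nll_sum, n_valid = state
    return nll_sum / n_valid


# Hypothetical states: process A saw 3 valid labels, process B saw 1.
state_a = (-(math.log(0.5) + math.log(0.5) + math.log(0.6)), 3)
state_b = (-math.log(0.9), 1)

# Correct global value: merge the raw sums and counts, then divide once.
global_nll = mean_nll(merge_states([state_a, state_b]))

# Averaging per-process means would weight B's single label as heavily as A's three.
naive = (mean_nll(state_a) + mean_nll(state_b)) / 2
print(global_nll, naive)
```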