Module laplace.utils

Sub-modules

laplace.utils.enums
laplace.utils.feature_extractor
laplace.utils.matrix
laplace.utils.metrics
laplace.utils.subnetmask
laplace.utils.swag
laplace.utils.utils

Functions

def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) ‑> torch.Tensor
def validate(laplace: BaseLaplace, val_loader: DataLoader, loss: torchmetrics.Metric | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels') ‑> float
def parameters_per_layer(model: nn.Module) ‑> list[int]

Get number of parameters per layer.

Parameters

model : torch.nn.Module
 

Returns

params_per_layer : list[int]
 
def invsqrt_precision(M: torch.Tensor) ‑> torch.Tensor

Compute M^{-0.5} as a tridiagonal matrix.

Parameters

M : torch.Tensor
 

Returns

M_invsqrt : torch.Tensor
 
def kron(t1: torch.Tensor, t2: torch.Tensor) ‑> torch.Tensor

Computes the Kronecker product between two tensors.

Parameters

t1 : torch.Tensor
 
t2 : torch.Tensor
 

Returns

kron_product : torch.Tensor
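
A minimal usage sketch (assuming kron matches the behaviour of torch.kron for two 2-D tensors, so the output dimensions are the products of the input dimensions):

    import torch
    from laplace.utils import kron

    Q = torch.randn(3, 3)
    H = torch.randn(5, 5)
    K = kron(Q, H)                      # Kronecker product of the two factors
    print(K.shape)                      # torch.Size([15, 15])
    # under the assumption above, this agrees with the built-in torch.kron
    print(torch.allclose(K, torch.kron(Q, H)))
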
 
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) ‑> torch.Tensor

Add the scalar value to the diagonal of X.

Parameters

X : torch.Tensor
 
value : torch.Tensor or float
 

Returns

X_add_scalar : torch.Tensor
 
def symeig(M: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]

Symmetric eigendecomposition that avoids failure cases by adding, and subsequently removing, jitter on the diagonal.

Parameters

M : torch.Tensor
 

Returns

L : torch.Tensor
eigenvalues
W : torch.Tensor
eigenvectors
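
The jitter-and-retry idea behind symeig can be sketched as follows (a simplified illustration of the technique, not the library's exact implementation; the jitter values are hypothetical):

    import torch

    def symeig_with_jitter(M: torch.Tensor, jitters=(0.0, 1e-8, 1e-6, 1e-4)):
        # try eigh directly; on failure, add a small multiple of the identity and retry
        eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        for eps in jitters:
            try:
                L, W = torch.linalg.eigh(M + eps * eye)
                return L - eps, W       # remove the added jitter from the eigenvalues
            except RuntimeError:        # torch.linalg.LinAlgError subclasses RuntimeError
                continue
        raise RuntimeError("eigendecomposition failed even with jitter")
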
def block_diag(blocks: list[torch.Tensor]) ‑> torch.Tensor

Compose block-diagonal matrix of individual blocks.

Parameters

blocks : list[torch.Tensor]
 

Returns

M : torch.Tensor
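
A small usage sketch (assuming square blocks are stacked along the diagonal and all off-diagonal entries are zero):

    import torch
    from laplace.utils import block_diag

    blocks = [torch.randn(2, 2), torch.randn(3, 3)]
    M = block_diag(blocks)
    print(M.shape)                      # torch.Size([5, 5])
    print(torch.all(M[:2, 2:] == 0))    # off-diagonal blocks are zero
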
 
def normal_samples(mean: torch.Tensor, var: torch.Tensor, n_samples: int, generator: torch.Generator | None = None) ‑> torch.Tensor

Produce samples from a batch of Normal distributions, parameterized either by a diagonal or a full covariance given by var.

Parameters

mean : torch.Tensor
(batch_size, output_dim)
var : torch.Tensor
(co)variance of the Normal distribution, (batch_size, output_dim, output_dim) or (batch_size, output_dim)
n_samples : int
number of samples to draw
generator : torch.Generator
random number generator
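
A usage sketch for the diagonal-covariance case (the sample shape in the comment is an assumption: a leading sample dimension in front of the batch):

    import torch
    from laplace.utils import normal_samples

    mean = torch.zeros(8, 4)                  # (batch_size, output_dim)
    var = 0.1 * torch.ones(8, 4)              # diagonal variances, same shape as mean
    gen = torch.Generator().manual_seed(0)
    samples = normal_samples(mean, var, n_samples=100, generator=gen)
    # assumed shape: (n_samples, batch_size, output_dim) = (100, 8, 4)
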
def _is_batchnorm(module: nn.Module) ‑> bool
def _is_valid_scalar(scalar: float | int | torch.Tensor) ‑> bool
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) ‑> torch.Tensor

Expand prior precision to match the shape of the model parameters.

Parameters

prior_prec : torch.Tensor
1-dimensional prior precision
model : torch.nn.Module
torch model with parameters that are regularized by prior_prec

Returns

expanded_prior_prec : torch.Tensor
expanded prior precision with the same shape as the vectorized model parameters
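
A usage sketch, reading "same shape as the vectorized model parameters" as the shape of torch.nn.utils.parameters_to_vector(model.parameters()):

    import torch
    import torch.nn as nn
    from laplace.utils import expand_prior_precision

    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
    n_params = sum(p.numel() for p in model.parameters())

    prior_prec = torch.tensor([1.0])          # scalar prior precision as a 1-dimensional tensor
    expanded = expand_prior_precision(prior_prec, model)
    print(expanded.shape)                     # expected: torch.Size([n_params])
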
def fix_prior_prec_structure(prior_prec_init: torch.Tensor, prior_structure: PriorStructure | str, n_layers: int, n_params: int, device: torch.device) ‑> torch.Tensor

Create a tensor of prior precision with the correct shape, depending on the choice of the prior structure type.

Parameters

prior_prec_init : torch.Tensor
the initial prior precision tensor (could be scalar)
prior_structure : PriorStructure | str
the choice of the prior structure type
n_layers : int
 
n_params : int
 
device : torch.device
 

Returns

correct_prior_precision : torch.Tensor
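
A sketch of how the structure choice is expected to affect the returned shape, assuming 'scalar' yields a single value, 'layerwise' one value per layer, and 'diag' one value per parameter (the lowercase strings are assumed to be valid PriorStructure values):

    import torch
    from laplace.utils import fix_prior_prec_structure

    init = torch.tensor(1.0)
    dev = torch.device("cpu")
    p_scalar = fix_prior_prec_structure(init, "scalar", n_layers=4, n_params=100, device=dev)
    p_layer = fix_prior_prec_structure(init, "layerwise", n_layers=4, n_params=100, device=dev)
    p_diag = fix_prior_prec_structure(init, "diag", n_layers=4, n_params=100, device=dev)
    # expected shapes under the assumption above: (1,), (4,), (100,)
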
 
def fit_diagonal_swag_var(model: nn.Module, train_loader: DataLoader, criterion: nn.CrossEntropyLoss | nn.MSELoss, n_snapshots_total: int = 40, snapshot_freq: int = 1, lr: float = 0.01, momentum: float = 0.9, weight_decay: float = 0.0003, min_var: float = 1e-30) ‑> torch.Tensor

Fit diagonal SWAG [1], which estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate.

Implementation partly adapted from:

  • https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
  • https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

References

[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.

Parameters

model : torch.nn.Module
 
train_loader : torch.utils.data.DataLoader
training data loader to use for snapshot collection
criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
loss function to use for snapshot collection
n_snapshots_total : int
total number of model snapshots to collect
snapshot_freq : int
snapshot collection frequency (in epochs)
lr : float
SGD learning rate for collecting snapshots
momentum : float
SGD momentum
weight_decay : float
SGD weight decay
min_var : float
minimum parameter variance to clamp to (for numerical stability)

Returns

param_variances : torch.Tensor
vector of marginal variances for each model parameter
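
A usage sketch for a small classification setup (the model, data, and hyperparameters below are placeholders):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace.utils import fit_diagonal_swag_var

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
    X, y = torch.randn(256, 10), torch.randint(0, 3, (256,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

    param_variances = fit_diagonal_swag_var(
        model, train_loader, criterion=nn.CrossEntropyLoss(),
        n_snapshots_total=40, snapshot_freq=1, lr=1e-2,
    )
    # one marginal variance per model parameter
    print(param_variances.shape)   # expected length: sum(p.numel() for p in model.parameters())
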

Classes

class FeatureExtractor (model: nn.Module, last_layer_name: str | None = None, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None)

Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that for every forward pass the last layer is the same. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.

Parameters

model : torch.nn.Module
PyTorch model
last_layer_name : str, default=None
if the name of the last layer is already known, otherwise it will be determined automatically.
enable_backprop : bool, default=False
whether to enable backprop through the feature extractor to get the gradients of the inputs. Useful e.g. for Bayesian optimization.
feature_reduction : FeatureReduction or str, default=None
when the last-layer features are a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: Make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, available options are {'pick_first', 'pick_last', 'average'}.

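
A usage sketch (the model below is a placeholder; the tuple order of forward_with_features, last-layer output first, is an assumption):

    import torch
    import torch.nn as nn
    from laplace.utils import FeatureExtractor

    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
    feature_extractor = FeatureExtractor(model, last_layer_name="3")  # name as in model.named_modules()

    x = torch.randn(16, 1, 28, 28)
    out, features = feature_extractor.forward_with_features(x)
    # out: last-layer output (16, 10); features: penultimate output (16, 128)
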

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool
var call_super_init : bool

Methods

def forward(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> Callable[..., Any]

Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters

x : torch.Tensor or a dict-like object containing the input tensors
one batch of data to use as input for the forward pass
def forward_with_features(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> tuple[torch.Tensor, torch.Tensor]

Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters

x : torch.Tensor or a dict-like object containing the input tensors
one batch of data to use as input for the forward pass
def set_last_layer(self, last_layer_name: str) ‑> None

Set the last layer of the model by its name. This sets the forward hook to get the output of the penultimate layer.

Parameters

last_layer_name : str
the name of the last layer (as listed in model.named_modules()).
def find_last_layer(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> torch.Tensor

Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of torch.nn.Linear. Might not work with every architecture, but is tested with all PyTorch torchvision classification models (besides SqueezeNet, which has no linear last layer).

Parameters

x : torch.Tensor or dict-like object containing the input tensors
one batch of data to use as input for the forward pass
class Kron (kfacs: list[tuple[torch.Tensor] | torch.Tensor])

Kronecker-factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors Q and H, while a single matrix is just a full block Hessian approximation.

Parameters

kfacs : list[Iterable[torch.Tensor] | torch.Tensor]
each element in the list is a tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in case of bias, for example)

Static methods

def init_from_model(model: nn.Module | Iterable[nn.Parameter], device: torch.device) ‑> Kron

Initialize Kronecker factors based on a model's architecture.

Parameters

model : nn.Module or iterable of parameters, e.g. model.parameters()
 
device : torch.device
 

Returns

kron : Kron
 

Methods

def decompose(self, damping: bool = False) ‑> KronDecomposed

Eigendecompose Kronecker factors and turn into KronDecomposed.

Parameters

damping : bool
use damping

Returns

kron_decomposed : KronDecomposed
 
def bmm(self, W: torch.Tensor, exponent: float = 1) ‑> torch.Tensor

Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive or a regularization term based on Kronecker factors, as in continual learning.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=1
can only be 1 for Kron; other exponent values of the Kronecker factors require KronDecomposed.

Returns

SW : torch.Tensor
result (batch, classes, params)
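
A sketch tying init_from_model, bmm, and logdet together (a freshly initialized Kron has trivial factors; in practice they are filled with curvature/Hessian information):

    import torch
    import torch.nn as nn
    from laplace.utils import Kron

    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 3))
    curv = Kron.init_from_model(model, device=torch.device("cpu"))

    n_params = sum(p.numel() for p in model.parameters())
    W = torch.randn(8, 3, n_params)    # (batch, classes, params)
    SW = curv.bmm(W)                   # H @ W, result (batch, classes, params)
    logdet = curv.logdet()             # log determinant of the block-diagonal H
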
def logdet(self) ‑> torch.Tensor

Compute the log determinant of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.

Returns

logdet : torch.Tensor
 
def diag(self) ‑> torch.Tensor

Extract diagonal of the entire Kronecker factorization.

Returns

diag : torch.Tensor
 
def to_matrix(self) ‑> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Returns

block_diag : torch.Tensor
 
class KronDecomposed (eigenvectors: list[tuple[torch.Tensor]], eigenvalues: list[tuple[torch.Tensor]], deltas: torch.Tensor | None = None, damping: bool = False)

Decomposed Kronecker-factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layer-wise scalars, but we can no longer add other Kron or KronDecomposed objects.

Parameters

eigenvectors : list[Tuple[torch.Tensor]]
eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues : list[Tuple[torch.Tensor]]
eigenvalues corresponding to matrices in a corresponding Kron
deltas : torch.Tensor
addend for each group of Kronecker factors representing, for example, a prior precision
damping : bool, default=False
use a damping approximation that mixes the prior and the Kronecker factors partially multiplicatively

Methods

def detach(self) ‑> KronDecomposed
def logdet(self) ‑> torch.Tensor

Compute the log determinant of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are taken into account.

Returns

logdet : torch.Tensor
 
def inv_square_form(self, W: torch.Tensor) ‑> torch.Tensor
def bmm(self, W: torch.Tensor, exponent: float = -1) ‑> torch.Tensor

Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. Compared to Kron.bmm(), a prior can be added here in form of deltas and the exponent can be other than just 1. Computes H^{exponent} W.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=-1
exponent of the Kronecker factorization

Returns

SW : torch.Tensor
result (batch, classes, params)
def diag(self, exponent: float = 1) ‑> torch.Tensor

Extract diagonal of the entire decomposed Kronecker factorization.

Parameters

exponent : float, default=1
exponent of the Kronecker factorization

Returns

diag : torch.Tensor
 
def to_matrix(self, exponent: float = 1) ‑> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Parameters

exponent : float, default=1
exponent of the Kronecker factorization

Returns

block_diag : torch.Tensor
 
class SubnetMask (model: nn.Module)

Base class for all subnetwork masks in this library (for subnetwork Laplace).

Parameters

model : torch.nn.Module
 

Instance variables

var indices : torch.LongTensor
var n_params_subnet : int

Methods

def convert_subnet_mask_to_indices(self, subnet_mask: torch.Tensor) ‑> torch.LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork
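
The conversion boils down to taking the positions of the ones in the binary mask; a minimal sketch of the idea (not the library code):

    import torch

    n_params = 10
    subnet_mask = torch.zeros(n_params)
    subnet_mask[[1, 4, 7]] = 1.0                            # parameters selected for the subnetwork
    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    print(subnet_mask_indices)                              # tensor([1, 4, 7])
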
def select(self, train_loader: DataLoader | None = None) ‑> torch.LongTensor

Select the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader, default=None
each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, the size of the data set

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork
def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader
each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, the size of the data set

Returns

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
class RandomSubnetMask (model: nn.Module, n_params_subnet: int)

Subnetwork mask of parameters sampled uniformly at random.
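
A usage sketch for selecting a random subnetwork (the model and data are placeholders):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace.utils import RandomSubnetMask

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    X, y = torch.randn(128, 10), torch.randint(0, 2, (128,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

    subnet_mask = RandomSubnetMask(model, n_params_subnet=100)
    subnet_indices = subnet_mask.select(train_loader)   # indices into the vectorized model parameters
    print(subnet_indices.shape)                          # expected: torch.Size([100])
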

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor

Inherited members

class LargestMagnitudeSubnetMask (model: nn.Module, n_params_subnet: int)

Subnetwork mask identifying the parameters with the largest magnitude.

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor

Inherited members

class LargestVarianceDiagLaplaceSubnetMask (model: nn.Module, n_params_subnet: int, diag_laplace_model: DiagLaplace)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using a diagonal Laplace approximation over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
diag_laplace_model : DiagLaplace
diagonal Laplace model to use for variance estimation

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor

Inherited members

class LargestVarianceSWAGSubnetMask (model: nn.Module, n_params_subnet: int, likelihood: Likelihood | str = Likelihood.CLASSIFICATION, swag_n_snapshots: int = 40, swag_snapshot_freq: int = 1, swag_lr: float = 0.01)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using diagonal SWAG over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
likelihood : str
'classification' or 'regression'
swag_n_snapshots : int
number of model snapshots to collect for SWAG
swag_snapshot_freq : int
SWAG snapshot collection frequency (in epochs)
swag_lr : float
learning rate for SWAG snapshot collection

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor

Inherited members

class ParamNameSubnetMask (model: nn.Module, parameter_names: list[str])

Subnetwork mask corresponding to the specified parameters of the neural network.

Parameters

model : torch.nn.Module
 
parameter_names : List[str]
list of names of the parameters (as in model.named_parameters()) that define the subnetwork

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the specified parameters.

Inherited members

class ModuleNameSubnetMask (model: nn.Module, module_names: list[str])

Subnetwork mask corresponding to the specified modules of the neural network.

Parameters

model : torch.nn.Module
 
module_names : list[str]
list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the specified modules.

Inherited members

class LastLayerSubnetMask (model: nn.Module, last_layer_name: str | None = None)

Subnetwork mask corresponding to the last layer of the neural network.

Parameters

model : torch.nn.Module
 
last_layer_name : str, default=None
name of the model's last layer, if None it will be determined automatically

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the last layer.

Inherited members

class RunningNLLMetric (ignore_index: int = -100)

Running negative log-likelihood (NLL) metric that accumulates the NLL over batches via repeated update() calls, following the torchmetrics Metric interface.

Parameters

ignore_index : int, default = -100
which class label to ignore when computing the NLL loss


Ancestors

  • torchmetrics.metric.Metric
  • torch.nn.modules.module.Module
  • abc.ABC

Class variables

var is_differentiable : Optional[bool]
var higher_is_better : Optional[bool]
var full_state_update : Optional[bool]
var plot_lower_bound : Optional[float]
var plot_upper_bound : Optional[float]
var plot_legend_name : Optional[str]

Methods

def update(self, probs: torch.Tensor, targets: torch.Tensor) ‑> None

Parameters

probs : torch.Tensor
probability tensor of shape (…, n_classes)
targets : torch.Tensor
integer tensor of shape (…)
def compute(self) ‑> torch.Tensor

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.
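
A usage sketch following the torchmetrics update/compute pattern with the documented update(probs, targets) signature:

    import torch
    from laplace.utils import RunningNLLMetric

    metric = RunningNLLMetric()
    for _ in range(3):                                       # e.g. loop over validation batches
        probs = torch.softmax(torch.randn(16, 5), dim=-1)    # (..., n_classes)
        targets = torch.randint(0, 5, (16,))                 # (...)
        metric.update(probs, targets)
    print(metric.compute())                                  # running NLL over all batches seen
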

class SubsetOfWeights (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var ALL
var LAST_LAYER
var SUBNETWORK
class HessianStructure (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var FULL
var KRON
var DIAG
var LOWRANK
class Likelihood (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var REGRESSION
var CLASSIFICATION
var REWARD_MODELING
class PredType (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var GLM
var NN
class LinkApprox (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var MC
var PROBIT
var BRIDGE
var BRIDGE_NORM
class TuningMethod (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var MARGLIK
var GRIDSEARCH
class PriorStructure (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var SCALAR
var DIAG
var LAYERWISE
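
Since these enums subclass builtins.str, their members can be used interchangeably with plain strings in the APIs above. A small sketch, assuming the member values are the lowercase strings used as defaults elsewhere on this page (e.g. 'glm', 'probit', 'classification'):

    from laplace.utils import PredType, LinkApprox, Likelihood

    # str-valued enum members compare equal to their plain-string counterparts
    print(PredType.GLM == "glm")                             # True under the assumption above
    print(LinkApprox.PROBIT == "probit")                     # True under the assumption above
    print(Likelihood.CLASSIFICATION == "classification")     # True under the assumption above
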