Module laplace.utils

Sub-modules

laplace.utils.feature_extractor
laplace.utils.matrix
laplace.utils.subnetmask
laplace.utils.swag
laplace.utils.utils

Functions

def get_nll(out_dist, targets)
def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100)
def parameters_per_layer(model)

Get the number of parameters per layer.

Parameters

model : torch.nn.Module
 

Returns

params_per_layer : list[int]
 
def invsqrt_precision(M)

Compute M^{-0.5} as a lower-triangular matrix, i.e. a factor L with L L^T = M^{-1}.

Parameters

M : torch.Tensor
 

Returns

M_invsqrt : torch.Tensor
 
def kron(t1, t2)

Computes the Kronecker product between two tensors.

Parameters

t1 : torch.Tensor
 
t2 : torch.Tensor
 

Returns

kron_product : torch.Tensor
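
For matrices of shape (m, n) and (p, q), the Kronecker product is the (m*p, n*q) block matrix whose (i, j) block is t1[i][j] * t2. The following dependency-free sketch on nested lists illustrates that layout. kron_lists is a hypothetical name; for tensors, PyTorch itself provides torch.kron.

```python
def kron_lists(a, b):
    """Kronecker product of two matrices given as lists of rows."""
    p, q = len(b), len(b[0])
    out = [[0] * (len(a[0]) * q) for _ in range(len(a) * p)]
    for i, row in enumerate(a):
        for j, a_ij in enumerate(row):
            # Block (i, j) of the result is a_ij * b.
            for k in range(p):
                for l in range(q):
                    out[i * p + k][j * q + l] = a_ij * b[k][l]
    return out


print(kron_lists([[1, 2], [3, 4]], [[0, 1], [1, 0]]))
# [[0, 1, 0, 2], [1, 0, 2, 0], [0, 3, 0, 4], [3, 0, 4, 0]]
```
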
 
def diagonal_add_scalar(X, value)

Add the scalar value to the diagonal of X.

Parameters

X : torch.Tensor
 
value : torch.Tensor or float
 

Returns

X_add_scalar : torch.Tensor
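
Adding a scalar to the diagonal is the same as computing X + value * I, a common way to add jitter or an isotropic prior precision to a curvature matrix. A minimal sketch on nested lists; the function name is hypothetical, and this version returns a copy instead of modifying its input.

```python
def diagonal_add_scalar_lists(x, value):
    """Return a copy of square matrix x with value added to every diagonal entry."""
    out = [list(row) for row in x]
    for i in range(len(out)):
        out[i][i] += value
    return out


print(diagonal_add_scalar_lists([[1, 2], [3, 4]], 0.5))  # [[1.5, 2], [3, 4.5]]
```
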
 
def symeig(M)

Symmetric eigendecomposition that avoids failure cases by adding jitter to the diagonal and removing it again.

Parameters

M : torch.Tensor
 

Returns

L : torch.Tensor
eigenvalues
W : torch.Tensor
eigenvectors
def block_diag(blocks)

Compose a block-diagonal matrix from individual blocks.

Parameters

blocks : list[torch.Tensor]
 

Returns

M : torch.Tensor
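
A block-diagonal matrix places each block on the diagonal and fills everything else with zeros, which is how per-layer blocks combine into one matrix over all parameters. A dependency-free sketch with a hypothetical name; for tensors, torch.block_diag(*blocks) builds the same matrix.

```python
def block_diag_lists(blocks):
    """Place each block on the diagonal of a larger zero matrix."""
    n_rows = sum(len(b) for b in blocks)
    n_cols = sum(len(b[0]) for b in blocks)
    out = [[0] * n_cols for _ in range(n_rows)]
    r = c = 0
    for b in blocks:
        for i, row in enumerate(b):
            # Copy one row of the block into its diagonal position.
            out[r + i][c:c + len(row)] = row
        r += len(b)
        c += len(b[0])
    return out


print(block_diag_lists([[[1, 2], [3, 4]], [[5]]]))
# [[1, 2, 0], [3, 4, 0], [0, 0, 5]]
```
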
 
def expand_prior_precision(prior_prec, model)

Expand prior precision to match the shape of the model parameters.

Parameters

prior_prec : torch.Tensor 1-dimensional
prior precision
model : torch.nn.Module
torch model with parameters that are regularized by prior_prec

Returns

expanded_prior_prec : torch.Tensor
expanded prior precision has the same shape as model parameters
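
A sketch of one plausible expansion rule, assuming prior_prec may have three lengths: a single scalar shared by all parameters, one value per parameter tensor (in model.parameters() order), or one value per scalar parameter. To run without PyTorch, the sketch takes the parameter tensor sizes directly; the function name and the param_sizes argument are hypothetical.

```python
def expand_prior_precision_sketch(prior_prec, param_sizes):
    """Expand a prior precision to one value per scalar model parameter.

    prior_prec: list of floats of length 1, len(param_sizes), or sum(param_sizes).
    param_sizes: number of scalar entries in each parameter tensor, in order.
    """
    n_params = sum(param_sizes)
    if len(prior_prec) == 1:
        # Scalar prior: the same precision for every parameter.
        return list(prior_prec) * n_params
    if len(prior_prec) == n_params:
        # Already one value per parameter (full diagonal).
        return list(prior_prec)
    if len(prior_prec) == len(param_sizes):
        # Layer-wise prior: repeat each value over its tensor's entries.
        return [p for p, size in zip(prior_prec, param_sizes) for _ in range(size)]
    raise ValueError("unsupported length of prior_prec")


print(expand_prior_precision_sketch([1.0, 5.0], [3, 1]))  # [1.0, 1.0, 1.0, 5.0]
```
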
def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=0.0003, min_var=1e-30)

Fit diagonal SWAG [1], which estimates the marginal variances of model parameters by computing the first and second moments of SGD iterates obtained with a large learning rate.

Implementation partly adapted from:

  • https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
  • https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

References

[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.

Parameters

model : torch.nn.Module
 
train_loader : torch.utils.data.DataLoader
training data loader to use for snapshot collection
criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
loss function to use for snapshot collection
n_snapshots_total : int
total number of model snapshots to collect
snapshot_freq : int
snapshot collection frequency (in epochs)
lr : float
SGD learning rate for collecting snapshots
momentum : float
SGD momentum
weight_decay : float
SGD weight decay
min_var : float
minimum parameter variance to clamp to (for numerical stability)

Returns

param_variances : torch.Tensor
vector of marginal variances for each model parameter
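
The moment computation behind diagonal SWAG can be sketched as follows, assuming the flattened parameter snapshots have already been collected (the SGD loop using lr, momentum, and weight_decay that produces them is omitted). Each parameter's variance is its second moment minus its squared first moment, clamped from below at min_var. diag_swag_variance is a hypothetical name.

```python
def diag_swag_variance(snapshots, min_var=1e-30):
    """Marginal variance of each parameter across SGD snapshots."""
    n = len(snapshots)
    n_params = len(snapshots[0])
    mean = [0.0] * n_params
    sq_mean = [0.0] * n_params
    for theta in snapshots:
        for i, t in enumerate(theta):
            # Accumulate the first and second moments.
            mean[i] += t / n
            sq_mean[i] += t * t / n
    # Var = E[theta^2] - E[theta]^2, clamped for numerical stability.
    return [max(s - m * m, min_var) for m, s in zip(mean, sq_mean)]


print(diag_swag_variance([[1.0, 2.0], [3.0, 2.0]]))  # [1.0, 1e-30]
```
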

Classes

class FeatureExtractor (model: torch.nn.modules.module.Module, last_layer_name: Optional[str] = None)

Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that it is the same for every forward pass. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.

Parameters

model : torch.nn.Module
PyTorch model
last_layer_name : str, default=None
name of the last layer, if already known; otherwise it will be determined automatically.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, x: torch.Tensor) -> torch.Tensor

Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters

x : torch.Tensor
one batch of data to use as input for the forward pass
def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters

x : torch.Tensor
one batch of data to use as input for the forward pass
def set_last_layer(self, last_layer_name: str) -> None

Set the last layer of the model by its name. This sets the forward hook to get the output of the penultimate layer.

Parameters

last_layer_name : str
the name of the last layer (as in model.named_modules()).
def find_last_layer(self, x: torch.Tensor) -> torch.Tensor

Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of torch.nn.Linear. It might not work with every architecture, but it has been tested with all torchvision classification models (except SqueezeNet, which has no linear last layer).

Parameters

x : torch.Tensor
one batch of data to use as input for the forward pass
class Kron (kfacs)

Kronecker factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors Q and H, and a single matrix is a full block Hessian approximation.

Parameters

kfacs : list[Tuple]
each element in the list is a Tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in case of bias, for example)

Static methods

def init_from_model(model, device)

Initialize Kronecker factors based on a model's architecture.

Parameters

model : torch.nn.Module
 
device : torch.device
 

Returns

kron : Kron
 

Methods

def decompose(self, damping=False)

Eigendecompose Kronecker factors and turn them into KronDecomposed.

Parameters

damping : bool
use damping

Returns

kron_decomposed : KronDecomposed
 
def bmm(self, W: torch.Tensor, exponent: float = 1) -> torch.Tensor

Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive distribution or a regularization term based on Kronecker factors, as in continual learning.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=1
can only be 1 for Kron; other exponent values of the Kronecker factors require KronDecomposed.

Returns

SW : torch.Tensor
result (batch, classes, params)
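
Kronecker structure is what keeps this multiplication cheap: multiplying kron(A, B) by a vector never requires forming the dense product. With row-major vectorization, kron(A, B) @ vec(X) = vec(A @ X @ B^T). The dependency-free sketch below checks this identity on small matrices; the helper names are hypothetical, and the sketch makes no claim about which factor Kron.bmm applies on which side.

```python
def matmul(a, b):
    """Product of matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def kron(a, b):
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def kron_matvec(a, b, x):
    """kron(a, b) @ vec(x), computed as vec(a @ x @ b^T) without forming kron(a, b)."""
    y = matmul(matmul(a, x), transpose(b))
    return [v for row in y for v in row]  # row-major vectorization


A = [[1, 2], [3, 4]]
B = [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
X = [[1, 0, 2], [3, 1, 1]]  # shape (columns of A, columns of B)

x_vec = [v for row in X for v in row]
direct = [sum(k * v for k, v in zip(row, x_vec)) for row in kron(A, B)]
print(direct == kron_matvec(A, B, X))  # True
```
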
def logdet(self) -> torch.Tensor

Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.

Returns

logdet : torch.Tensor
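
Only the small factors are needed here because det(kron(A, B)) = det(A)^m * det(B)^n for A of size n x n and B of size m x m, so log det(kron(A, B)) = m * log det(A) + n * log det(B); for a block-diagonal matrix, the per-block log determinants then add up. A dependency-free numerical check with hypothetical helper names, assuming symmetric positive definite factors:

```python
import math


def kron(a, b):
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def logdet(mat):
    """Log determinant of a symmetric positive definite matrix (no pivoting needed)."""
    mat = [[float(v) for v in row] for row in mat]
    n = len(mat)
    total = 0.0
    for c in range(n):
        # Eliminate column c below the pivot; SPD pivots stay positive.
        for r in range(c + 1, n):
            f = mat[r][c] / mat[c][c]
            for k in range(c, n):
                mat[r][k] -= f * mat[c][k]
        total += math.log(mat[c][c])
    return total


A = [[2.0, 1.0], [1.0, 3.0]]                              # 2 x 2, det(A) = 5
B = [[4.0, 1.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 1.0]]  # 3 x 3, det(B) = 7
n_a, n_b = len(A), len(B)
dense = logdet(kron(A, B))
factored = n_b * logdet(A) + n_a * logdet(B)
print(abs(dense - factored) < 1e-9)  # True
```
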
 
def diag(self) -> torch.Tensor

Extract diagonal of the entire Kronecker factorization.

Returns

diag : torch.Tensor
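
The diagonal of a Kronecker product needs only the diagonals of its factors: diag(kron(A, B)) = kron(diag(A), diag(B)). Concatenating these per-block diagonals gives the diagonal of the whole block-diagonal factorization without building a dense matrix. A sketch with hypothetical helper names:

```python
def kron(a, b):
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def diag_of_kron(a, b):
    """Diagonal of kron(a, b) computed from the factors' diagonals only."""
    return [a[i][i] * b[k][k] for i in range(len(a)) for k in range(len(b))]


A = [[2, 1], [1, 3]]
B = [[4, 1], [1, 5]]
K = kron(A, B)
print(diag_of_kron(A, B))                # [8, 10, 12, 15]
print([K[i][i] for i in range(len(K))])  # [8, 10, 12, 15]
```
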
 
def to_matrix(self) -> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it allocates large amounts of memory for big architectures.

Returns

block_diag : torch.Tensor
 
class KronDecomposed (eigenvectors, eigenvalues, deltas=None, damping=False)

Decomposed Kronecker factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layerwise scalars, but we can no longer add another Kron or KronDecomposed.

Parameters

eigenvectors : list[Tuple[torch.Tensor]]
eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues : list[Tuple[torch.Tensor]]
eigenvalues corresponding to matrices in a corresponding Kron
deltas : torch.Tensor
addend for each group of Kronecker factors representing, for example, a prior precision
damping : bool, default=False
use the damping approximation, which mixes prior and Kron partially multiplicatively

Methods

def detach(self)
def logdet(self) -> torch.Tensor

Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), the additive deltas corresponding to prior precisions are included.

Returns

logdet : torch.Tensor
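
Decomposing once pays off because adding a prior precision delta to a Kronecker-factored block leaves its eigenvectors unchanged: if A has eigenvalues lam_i and B has eigenvalues mu_j, then kron(A, B) + delta * I has eigenvalues lam_i * mu_j + delta, so its log determinant is the sum of log(lam_i * mu_j + delta). The sketch below checks this numerically for symmetric 2 x 2 factors; helper names are hypothetical, and the damping variant is not covered.

```python
import math


def kron(a, b):
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def logdet(mat):
    """Log determinant of a symmetric positive definite matrix via elimination."""
    mat = [[float(v) for v in row] for row in mat]
    n = len(mat)
    total = 0.0
    for c in range(n):
        for r in range(c + 1, n):
            f = mat[r][c] / mat[c][c]
            for k in range(c, n):
                mat[r][k] -= f * mat[c][k]
        total += math.log(mat[c][c])
    return total


def eig_sym2(m):
    """Eigenvalues of a symmetric 2 x 2 matrix in closed form."""
    (a, b), (_, d) = m
    mid = (a + d) / 2
    radius = math.sqrt(((a - d) / 2) ** 2 + b * b)
    return [mid - radius, mid + radius]


A = [[2.0, 1.0], [1.0, 3.0]]
B = [[4.0, 1.0], [1.0, 5.0]]
delta = 0.5  # e.g. a prior precision added to this block

# Dense route: form kron(A, B) + delta * I and take its log determinant.
K = kron(A, B)
for i in range(len(K)):
    K[i][i] += delta
dense = logdet(K)

# Decomposed route: only the eigenvalues of the small factors are needed.
decomposed = sum(math.log(lam * mu + delta) for lam in eig_sym2(A) for mu in eig_sym2(B))
print(abs(dense - decomposed) < 1e-9)  # True
```
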
 
def inv_square_form(self, W: torch.Tensor) -> torch.Tensor
def bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor

Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive distribution or a regularization loss. Compared to Kron.bmm(), a prior can be added here in the form of deltas, and the exponent can take values other than 1. Computes H^{exponent} W.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=-1
 

Returns

SW : torch.Tensor
result (batch, classes, params)
def to_matrix(self, exponent: float = 1) -> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it allocates large amounts of memory for big architectures.

Returns

block_diag : torch.Tensor
 
class SubnetMask (model)

Base class for all subnetwork masks in this library (for subnetwork Laplace).

Parameters

model : torch.nn.Module
 

Subclasses

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • ParamNameSubnetMask
  • ModuleNameSubnetMask

Instance variables

var indices
var n_params_subnet

Methods

def convert_subnet_mask_to_indices(self, subnet_mask)

Converts a subnetwork mask into subnetwork indices.

Parameters

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork
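
Converting a binary mask to indices amounts to listing the positions of its nonzero entries. A dependency-free sketch with a hypothetical name; for a tensor mask, subnet_mask.nonzero(as_tuple=True)[0] yields the same positions as a LongTensor.

```python
def mask_to_indices(subnet_mask):
    """Positions of the 1s in a binary mask over the flattened parameters."""
    if any(m not in (0, 1) for m in subnet_mask):
        raise ValueError("subnet_mask must be binary")
    return [i for i, m in enumerate(subnet_mask) if m == 1]


print(mask_to_indices([0, 1, 1, 0, 1]))  # [1, 2, 4]
```
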
def select(self, train_loader=None)

Select the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader, default=None
each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, size of the data set

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork
def get_subnet_mask(self, train_loader)

Get the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader
each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, size of the data set

Returns

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
class RandomSubnetMask (model, n_params_subnet)

Subnetwork mask of parameters sampled uniformly at random.
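
Uniform random selection fits the same score-then-select pattern as the other score-based masks: give every parameter an independent uniform score and keep the n_params_subnet highest, which amounts to sampling positions uniformly without replacement. A dependency-free sketch with hypothetical names; it illustrates the idea, not the library's compute_param_scores.

```python
import random


def random_subnet_mask(n_params, n_params_subnet, seed=None):
    """Binary mask with n_params_subnet ones at uniformly random positions."""
    rng = random.Random(seed)
    # Uniform random scores followed by top-k selection.
    scores = [rng.random() for _ in range(n_params)]
    top = sorted(range(n_params), key=scores.__getitem__, reverse=True)[:n_params_subnet]
    mask = [0] * n_params
    for i in top:
        mask[i] = 1
    return mask


mask = random_subnet_mask(10, 3, seed=0)
print(sum(mask))  # 3
```
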

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader)

class LargestMagnitudeSubnetMask (model, n_params_subnet)

Subnetwork mask identifying the parameters with the largest magnitude.
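
With magnitude scores, selection reduces to keeping the n_params_subnet entries of the flattened parameter vector with the largest absolute value. A dependency-free sketch with hypothetical names; on tensors, torch.topk(params.abs(), n_params_subnet).indices gives the corresponding positions.

```python
def top_k_mask(scores, n_params_subnet):
    """Binary mask selecting the n_params_subnet highest-scoring parameters."""
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    mask = [0] * len(scores)
    for i in order[:n_params_subnet]:
        mask[i] = 1
    return mask


params = [0.1, -2.0, 0.5, 1.5]  # flattened model parameters
mask = top_k_mask([abs(p) for p in params], n_params_subnet=2)
print(mask)  # [0, 1, 0, 1]
```
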

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader)

class LargestVarianceDiagLaplaceSubnetMask (model, n_params_subnet, diag_laplace_model)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using a diagonal Laplace approximation over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
diag_laplace_model : DiagLaplace
diagonal Laplace model to use for variance estimation

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader)

class LargestVarianceSWAGSubnetMask (model, n_params_subnet, likelihood='classification', swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using diagonal SWAG over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
likelihood : str
'classification' or 'regression'
swag_n_snapshots : int
number of model snapshots to collect for SWAG
swag_snapshot_freq : int
SWAG snapshot collection frequency (in epochs)
swag_lr : float
learning rate for SWAG snapshot collection

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader)

class ParamNameSubnetMask (model, parameter_names)

Subnetwork mask corresponding to the specified parameters of the neural network.

Parameters

model : torch.nn.Module
 
parameter_names : List[str]
list of names of the parameters (as in model.named_parameters()) that define the subnetwork
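
A mask over named parameters can be built by walking the parameters in model.named_parameters() order and marking every entry of the selected tensors. To stay dependency-free, the sketch takes (name, number of elements) pairs instead of a model; this input format and the function name are hypothetical.

```python
def param_name_mask(named_param_sizes, parameter_names):
    """Binary mask over the flattened parameters of the named tensors.

    named_param_sizes: (name, number of elements) pairs in named_parameters() order.
    """
    wanted = set(parameter_names)
    unknown = wanted - {name for name, _ in named_param_sizes}
    if unknown:
        raise ValueError(f"unknown parameter names: {sorted(unknown)}")
    mask = []
    for name, numel in named_param_sizes:
        # Every entry of a selected tensor belongs to the subnetwork.
        mask.extend([1 if name in wanted else 0] * numel)
    return mask


sizes = [("0.weight", 4), ("0.bias", 2), ("2.weight", 2)]
print(param_name_mask(sizes, ["0.bias"]))  # [0, 0, 0, 0, 1, 1, 0, 0]
```
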

Ancestors

  • SubnetMask

Methods

def get_subnet_mask(self, train_loader)

Get the subnetwork mask identifying the specified parameters.

class ModuleNameSubnetMask (model, module_names)

Subnetwork mask corresponding to the specified modules of the neural network.

Parameters

model : torch.nn.Module
 
module_names : List[str]
list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules must not have children, i.e. they must be leaf modules

Ancestors

  • SubnetMask

Subclasses

  • LastLayerSubnetMask

Methods

def get_subnet_mask(self, train_loader)

Get the subnetwork mask identifying the specified modules.

class LastLayerSubnetMask (model, last_layer_name=None)

Subnetwork mask corresponding to the last layer of the neural network.

Parameters

model : torch.nn.Module
 
last_layer_name : str, default=None
name of the model's last layer; if None, it will be determined automatically

Ancestors

  • ModuleNameSubnetMask
  • SubnetMask

Methods

def get_subnet_mask(self, train_loader)

Get the subnetwork mask identifying the last layer.
