Module laplace.utils
Sub-modules
laplace.utils.enums
laplace.utils.feature_extractor
laplace.utils.matrix
laplace.utils.metrics
laplace.utils.subnetmask
laplace.utils.swag
laplace.utils.utils
Functions
def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) ‑> torch.Tensor
def validate(laplace: BaseLaplace, val_loader: DataLoader, loss: torchmetrics.Metric | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels')
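Example (a minimal sketch of calling validate; the fitted Laplace object la, the val_loader, and the choice of torchmetrics metric are illustrative assumptions, not part of this reference):
import torchmetrics
from laplace.utils import validate

# la: a fitted laplace.BaseLaplace instance, val_loader: a torch DataLoader (both assumed to exist)
metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)
result = validate(la, val_loader, metric, pred_type="glm", link_approx="probit", n_samples=100)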
def parameters_per_layer(model: nn.Module) ‑> list[int]
-
Get number of parameters per layer.
Parameters
model
:torch.nn.Module
Returns
params_per_layer
:list[int]
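Example (a minimal sketch; the toy model is only for illustration):
import torch.nn as nn
from laplace.utils import parameters_per_layer

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
counts = parameters_per_layer(model)
print(counts, sum(counts))  # the counts sum to the total number of model parameters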
def invsqrt_precision(M: torch.Tensor) ‑> torch.Tensor
-
Compute
M^{-0.5}
as a tridiagonal matrix.
Parameters
M
:torch.Tensor
Returns
M_invsqrt
:torch.Tensor
def kron(t1: torch.Tensor, t2: torch.Tensor) ‑> torch.Tensor
-
Computes the Kronecker product between two tensors.
Parameters
t1
:torch.Tensor
t2
:torch.Tensor
Returns
kron_product
:torch.Tensor
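Example (a minimal sketch):
import torch
from laplace.utils import kron

A, B = torch.randn(2, 3), torch.randn(4, 5)
K = kron(A, B)
print(K.shape)  # torch.Size([8, 15]), the usual Kronecker-product shape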
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) ‑> torch.Tensor
-
Add scalar value
value
to diagonal of X.
Parameters
X
:torch.Tensor
value
:torch.Tensor
or float
Returns
X_add_scalar
:torch.Tensor
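Example (a minimal sketch):
import torch
from laplace.utils import diagonal_add_scalar

X = torch.zeros(3, 3)
Y = diagonal_add_scalar(X, 0.1)  # expected to equal X + 0.1 * torch.eye(3)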
def symeig(M: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]
-
Symmetric eigendecomposition avoiding failure cases by adding and removing jitter to the diagonal.
Parameters
M
:torch.Tensor
Returns
L
:torch.Tensor
- eigenvalues
W
:torch.Tensor
- eigenvectors
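Example (a minimal sketch for a symmetric positive semi-definite matrix; the reconstruction check assumes eigenvectors are returned as columns, as in torch.linalg.eigh):
import torch
from laplace.utils import symeig

A = torch.randn(5, 5)
M = A @ A.T                 # symmetric PSD matrix
L, W = symeig(M)            # eigenvalues L, eigenvectors W
print(torch.allclose(W @ torch.diag(L) @ W.T, M, atol=1e-4))  # should print True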
def block_diag(blocks: list[torch.Tensor]) ‑> torch.Tensor
-
Compose block-diagonal matrix of individual blocks.
Parameters
blocks
:list[torch.Tensor]
Returns
M
:torch.Tensor
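Example (a minimal sketch):
import torch
from laplace.utils import block_diag

blocks = [torch.ones(2, 2), 3 * torch.ones(3, 3)]
M = block_diag(blocks)
print(M.shape)  # torch.Size([5, 5]), with the blocks placed on the diagonal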
def normal_samples(mean: torch.Tensor, var: torch.Tensor, n_samples: int, generator: torch.Generator | None = None)
-
Produce samples from a batch of Normal distributions either parameterized by a diagonal or full covariance given by
var
.
Parameters
mean
:torch.Tensor
(batch_size, output_dim)
var
:torch.Tensor
- (co)variance of the Normal distribution
(batch_size, output_dim, output_dim)
or (batch_size, output_dim)
generator
:torch.Generator
- random number generator
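Example (a minimal sketch, diagonal-covariance case; the exact output layout should be checked against your installed version):
import torch
from laplace.utils import normal_samples

mean = torch.zeros(8, 3)     # (batch_size, output_dim)
var = torch.ones(8, 3)       # diagonal variances, same shape as mean
samples = normal_samples(mean, var, n_samples=100)
print(samples.shape)         # presumably (n_samples, batch_size, output_dim)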
def _is_batchnorm(module: nn.Module) ‑> bool
def _is_valid_scalar(scalar: float | int | torch.Tensor)
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) ‑> torch.Tensor
-
Expand prior precision to match the shape of the model parameters.
Parameters
prior_prec
:torch.Tensor 1-dimensional
- prior precision
model
:torch.nn.Module
- torch model with parameters that are regularized by prior_prec
Returns
expanded_prior_prec
:torch.Tensor
- expanded prior precision has the same shape as model parameters
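Example (a minimal sketch of expanding a layer-wise prior precision; the toy model and the one-value-per-parameter-group assumption are for illustration only):
import torch
import torch.nn as nn
from laplace.utils import expand_prior_precision

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
prior_prec = torch.ones(len(list(model.parameters())))   # one precision per parameter group
expanded = expand_prior_precision(prior_prec, model)
print(expanded.shape)  # same length as torch.nn.utils.parameters_to_vector(model.parameters())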
def fix_prior_prec_structure(prior_prec_init: torch.Tensor, prior_structure: PriorStructure | str, n_layers: int, n_params: int, device: torch.device)
-
Create a tensor of prior precision with the correct shape, depending on the choice of the prior structure type.
Parameters
prior_prec_init
:torch.Tensor
- the initial prior precision tensor (could be scalar)
prior_structure
:PriorStructure | str
- the choice of the prior structure type
n_layers
:int
n_params
:int
device
:torch.device
Returns
correct_prior_precision
:torch.Tensor
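Example (a minimal sketch; the string 'layerwise' is assumed to be a valid PriorStructure value, and the expected output shape is an assumption):
import torch
from laplace.utils import fix_prior_prec_structure

prior_prec = fix_prior_prec_structure(
    torch.tensor(1.0), "layerwise", n_layers=4, n_params=1000, device=torch.device("cpu")
)
print(prior_prec.shape)  # expected torch.Size([4]) for a layer-wise prior structure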
def fit_diagonal_swag_var(model: nn.Module, train_loader: DataLoader, criterion: nn.CrossEntropyLoss | nn.MSELoss, n_snapshots_total: int = 40, snapshot_freq: int = 1, lr: float = 0.01, momentum: float = 0.9, weight_decay: float = 0.0003, min_var: float = 1e-30)
-
Fit diagonal SWAG [1], which estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate.
Implementation partly adapted from:
- https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
References
[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.
Parameters
model
:torch.nn.Module
train_loader
:torch.utils.data.DataLoader
- training data loader to use for snapshot collection
criterion
:torch.nn.CrossEntropyLoss
or torch.nn.MSELoss
- loss function to use for snapshot collection
n_snapshots_total
:int
- total number of model snapshots to collect
snapshot_freq
:int
- snapshot collection frequency (in epochs)
lr
:float
- SGD learning rate for collecting snapshots
momentum
:float
- SGD momentum
weight_decay
:float
- SGD weight decay
min_var
:float
- minimum parameter variance to clamp to (for numerical stability)
Returns
param_variances
:torch.Tensor
- vector of marginal variances for each model parameter
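Example (a minimal sketch; model and train_loader are assumed to be an existing classifier and its training DataLoader):
import torch.nn as nn
from laplace.utils import fit_diagonal_swag_var

param_variances = fit_diagonal_swag_var(
    model, train_loader, criterion=nn.CrossEntropyLoss(),
    n_snapshots_total=40, snapshot_freq=1, lr=1e-2,
)
print(param_variances.shape)  # one marginal variance per model parameter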
Classes
class SoDSampler (N, M, seed: int = 0)
-
Base class for all Samplers.
Every Sampler subclass has to provide an
__iter__
method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a
__len__
method that returns the length of the returned iterators.
Ancestors
- torch.utils.data.sampler.Sampler
- typing.Generic
class FeatureExtractor (model: nn.Module, last_layer_name: str | None = None, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None)
-
Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that for every forward pass the last layer is the same. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.
Parameters
model
:torch.nn.Module
- PyTorch model
last_layer_name
:str
, default=None
- if the name of the last layer is already known, otherwise it will be determined automatically.
enable_backprop
:bool
, default=False
- whether to enable backprop through the feature extractor to get the gradients of the inputs. Useful for e.g. Bayesian optimization.
feature_reduction
:FeatureReduction
or str
, default=None
- when the last-layer
features
is a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim)
. But the last layer maps (batch_size, embd_dim)
to (batch_size, n_classes)
. Note: Make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, available options are {'pick_first', 'pick_last', 'average'}
.
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
var call_super_init : bool
Methods
def forward(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> Callable[..., Any]
-
Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time.
Parameters
x
:torch.Tensor
or a dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
def forward_with_features(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any])
-
Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time.
Parameters
x
:torch.Tensor
or a dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
def set_last_layer(self, last_layer_name: str) ‑> None
-
Set the last layer of the model by its name. This sets the forward hook to get the output of the penultimate layer.
Parameters
last_layer_name
:str
- the name of the last layer (fixed in
model.named_modules()
).
def find_last_layer(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any])
-
Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of
torch.nn.Linear
. Might not work with every architecture, but is tested with all PyTorch torchvision classification models (besides SqueezeNet, which has no linear last layer).
Parameters
x
:torch.Tensor
or dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
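Example (a minimal sketch; the toy model and the (output, features) return order reflect this documentation and may need checking against your installed version):
import torch
import torch.nn as nn
from laplace.utils import FeatureExtractor

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
fe = FeatureExtractor(model)                 # last layer is determined on the first forward pass
x = torch.randn(5, 10)
out, features = fe.forward_with_features(x)
print(out.shape, features.shape)             # (5, 2) last-layer output, (5, 20) penultimate features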
class Kron (kfacs: list[tuple[torch.Tensor] | torch.Tensor])
-
Kronecker factored approximate curvature representation for a corresponding neural network. Each element in
kfacs
is either a tuple or single matrix. A tuple represents two Kronecker factors Q and H, and a single element is just a full block Hessian approximation.
Parameters
kfacs
:list[Iterable[torch.Tensor] | torch.Tensor]
- each element in the list is a tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in case of bias, for example)
Static methods
def init_from_model(model: nn.Module | Iterable[nn.Parameter], device: torch.device)
-
Initialize Kronecker factors based on a model's architecture.
Parameters
model
:nn.Module
or iterable
of parameters, e.g. model.parameters()
device
:torch.device
Returns
kron
:Kron
Methods
def decompose(self, damping: bool = False) ‑> KronDecomposed
-
Eigendecompose Kronecker factors and turn into
KronDecomposed
.
Parameters
damping
:bool
- use damping
Returns
kron_decomposed
:KronDecomposed
def bmm(self, W: torch.Tensor, exponent: float = 1) ‑> torch.Tensor
-
Batched matrix multiplication with the Kronecker factors. If Kron is
H
, we compute H @ W
. This is useful for computing the predictive or a regularization based on Kronecker factors as in continual learning.
Parameters
W
:torch.Tensor
- matrix
(batch, classes, params)
exponent
:float
, default=1
- can only be
1
for Kron; requires KronDecomposed
for other exponent values of the Kronecker factors.
Returns
SW
:torch.Tensor
- result
(batch, classes, params)
def logdet(self) ‑> torch.Tensor
-
Compute the log determinant of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.
Returns
logdet
:torch.Tensor
def diag(self) ‑> torch.Tensor
-
Extract diagonal of the entire Kronecker factorization.
Returns
diag
:torch.Tensor
def to_matrix(self) ‑> torch.Tensor
-
Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes as it will allocate large amounts of memory for big architectures.
Returns
block_diag
:torch.Tensor
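Example (a minimal sketch with hand-made identity factors; in practice Kron objects are produced by the curvature backends rather than constructed manually):
import torch
from laplace.utils import Kron

kfacs = [(torch.eye(3), torch.eye(2)), torch.eye(4)]   # two Kronecker factors plus a full block (e.g. a bias)
K = Kron(kfacs)
print(K.logdet())             # 0 for identity factors
print(K.to_matrix().shape)    # (3*2 + 4, 3*2 + 4) = (10, 10); dense form, for testing only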
class KronDecomposed (eigenvectors: list[tuple[torch.Tensor]], eigenvalues: list[tuple[torch.Tensor]], deltas: torch.Tensor | None = None, damping: bool = False)
-
Decomposed Kronecker factored approximate curvature representation for a corresponding neural network. Each matrix in
Kron
is decomposed to obtain KronDecomposed
. Front-loading decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron
, we can add scalar or layerwise scalars but we cannot add other Kron
or KronDecomposed
anymore.
Parameters
eigenvectors
:list[Tuple[torch.Tensor]]
- eigenvectors corresponding to matrices in a corresponding
Kron
eigenvalues
:list[Tuple[torch.Tensor]]
- eigenvalues corresponding to matrices in a corresponding
Kron
deltas
:torch.Tensor
- addend for each group of Kronecker factors representing, for example, a prior precision
damping
:bool
, default=False
- use damping approximation, mixing prior and Kron partially multiplicatively
Methods
def detach(self) ‑> KronDecomposed
def logdet(self) ‑> torch.Tensor
-
Compute the log determinant of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to
Kron.logdet()
, additive deltas
corresponding to prior precisions are added.
Returns
logdet
:torch.Tensor
def inv_square_form(self, W: torch.Tensor) ‑> torch.Tensor
def bmm(self, W: torch.Tensor, exponent: float = -1) ‑> torch.Tensor
-
Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. Compared to
Kron.bmm()
, a prior can be added here in the form of deltas
and the exponent can be other than just 1. Computes H^{exponent} W.
Parameters
W
:torch.Tensor
- matrix
(batch, classes, params)
exponent
:float
, default=-1
Returns
SW
:torch.Tensor
- result
(batch, classes, params)
def diag(self, exponent: float = 1) ‑> torch.Tensor
-
Extract diagonal of the entire decomposed Kronecker factorization.
Parameters
exponent
:float
, default=1
- exponent of the Kronecker factorization
Returns
diag
:torch.Tensor
def to_matrix(self, exponent: float = 1) ‑> torch.Tensor
-
Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes as it will allocate large amounts of memory for big architectures.
Parameters
exponent
:float
, default=1
- exponent of the Kronecker factorization
Returns
block_diag
:torch.Tensor
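Example (a minimal sketch; KronDecomposed is typically obtained via Kron.decompose() rather than constructed directly):
import torch
from laplace.utils import Kron

kfacs = [(torch.eye(3), torch.eye(2)), torch.eye(4)]
K_dec = Kron(kfacs).decompose()
print(K_dec.logdet())                    # includes the additive deltas (zero by default)
print(K_dec.diag(exponent=-1).shape)     # diagonal of the approximate inverse, shape (10,)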
class SubnetMask (model: nn.Module)
-
Base class for all subnetwork masks in this library (for subnetwork Laplace).
Parameters
model
:torch.nn.Module
Subclasses
- ModuleNameSubnetMask
- ParamNameSubnetMask
- laplace.utils.subnetmask.ScoreBasedSubnetMask
Instance variables
prop indices : torch.LongTensor
prop n_params_subnet : int
Methods
def convert_subnet_mask_to_indices(self, subnet_mask: torch.Tensor) ‑> torch.LongTensor
-
Converts a subnetwork mask into subnetwork indices.
Parameters
subnet_mask
:torch.Tensor
- a binary vector of size (n_params) where 1s locate the subnetwork parameters
within the vectorized model parameters
(i.e.
torch.nn.utils.parameters_to_vector(model.parameters())
)
Returns
subnet_mask_indices
:torch.LongTensor
- a vector of indices of the vectorized model parameters
(i.e.
torch.nn.utils.parameters_to_vector(model.parameters())
) that define the subnetwork
def select(self, train_loader: DataLoader | None = None)
-
Select the subnetwork mask.
Parameters
train_loader
:torch.utils.data.DataLoader
, default=None
- each iterate is a training batch (X, y);
train_loader.dataset
needs to be set to access N, size of the data set
Returns
subnet_mask_indices
:torch.LongTensor
- a vector of indices of the vectorized model parameters
(i.e.
torch.nn.utils.parameters_to_vector(model.parameters())
) that define the subnetwork
def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor
-
Get the subnetwork mask.
Parameters
train_loader
:torch.utils.data.DataLoader
- each iterate is a training batch (X, y);
train_loader.dataset
needs to be set to access N, size of the data set
Returns
subnet_mask
:torch.Tensor
- a binary vector of size (n_params) where 1s locate the subnetwork parameters
within the vectorized model parameters
(i.e.
torch.nn.utils.parameters_to_vector(model.parameters())
)
class RandomSubnetMask (model: nn.Module, n_params_subnet: int)
-
Subnetwork mask of parameters sampled uniformly at random.
Ancestors
- laplace.utils.subnetmask.ScoreBasedSubnetMask
- SubnetMask
Methods
def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
Inherited members
class LargestMagnitudeSubnetMask (model: nn.Module, n_params_subnet: int)
-
Subnetwork mask identifying the parameters with the largest magnitude.
Ancestors
- laplace.utils.subnetmask.ScoreBasedSubnetMask
- SubnetMask
Methods
def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
Inherited members
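Example (a minimal sketch of selecting a subnetwork; magnitude-based scores should not require training data, but pass a train_loader to select() if your version expects one; the resulting indices are typically handed to a subnetwork Laplace approximation):
import torch.nn as nn
from laplace.utils import LargestMagnitudeSubnetMask

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
subnet_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=50)
subnet_indices = subnet_mask.select()
print(subnet_indices.shape)   # torch.Size([50]) indices into the vectorized parameters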
class LargestVarianceDiagLaplaceSubnetMask (model: nn.Module, n_params_subnet: int, diag_laplace_model: DiagLaplace)
-
Subnetwork mask identifying the parameters with the largest marginal variances (estimated using a diagonal Laplace approximation over all model parameters).
Parameters
model
:torch.nn.Module
n_params_subnet
:int
- number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
diag_laplace_model
:DiagLaplace
- diagonal Laplace model to use for variance estimation
Ancestors
- laplace.utils.subnetmask.ScoreBasedSubnetMask
- SubnetMask
Methods
def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
Inherited members
class LargestVarianceSWAGSubnetMask (model: nn.Module, n_params_subnet: int, likelihood: Likelihood | str = Likelihood.CLASSIFICATION, swag_n_snapshots: int = 40, swag_snapshot_freq: int = 1, swag_lr: float = 0.01)
-
Subnetwork mask identifying the parameters with the largest marginal variances (estimated using diagonal SWAG over all model parameters).
Parameters
model
:torch.nn.Module
n_params_subnet
:int
- number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
likelihood
:str
- 'classification' or 'regression'
swag_n_snapshots
:int
- number of model snapshots to collect for SWAG
swag_snapshot_freq
:int
- SWAG snapshot collection frequency (in epochs)
swag_lr
:float
- learning rate for SWAG snapshot collection
Ancestors
- laplace.utils.subnetmask.ScoreBasedSubnetMask
- SubnetMask
Methods
def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
Inherited members
class ParamNameSubnetMask (model: nn.Module, parameter_names: list[str])
-
Subnetwork mask corresponding to the specified parameters of the neural network.
Parameters
model
:torch.nn.Module
parameter_names
:List[str]
- list of names of the parameters (as in
model.named_parameters()
) that define the subnetwork
Ancestors
- SubnetMask
Methods
def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor
-
Get the subnetwork mask identifying the specified parameters.
Inherited members
class ModuleNameSubnetMask (model: nn.Module, module_names: list[str])
-
Subnetwork mask corresponding to the specified modules of the neural network.
Parameters
model
:torch.nn.Module
module_names
:List[str]
- list of names of the modules (as in
model.named_modules()
) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules
Ancestors
- SubnetMask
Subclasses
- LastLayerSubnetMask
Methods
def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor
-
Get the subnetwork mask identifying the specified modules.
Inherited members
class LastLayerSubnetMask (model: nn.Module, last_layer_name: str | None = None)
-
Subnetwork mask corresponding to the last layer of the neural network.
Parameters
model
:torch.nn.Module
last_layer_name
:str
, default=None
- name of the model's last layer, if None it will be determined automatically
Ancestors
- ModuleNameSubnetMask
- SubnetMask
Methods
def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor
-
Get the subnetwork mask identifying the last layer.
Inherited members
class RunningNLLMetric (ignore_index: int = -100)
-
NLL metric that accumulates the negative log-likelihood over minibatches, i.e. a running estimate over all predictions seen so far.
Parameters
ignore_index
:int
, default=-100
- which class label to ignore when computing the NLL loss
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var is_differentiable : Optional[bool]
var higher_is_better : Optional[bool]
var full_state_update : Optional[bool]
var plot_lower_bound : Optional[float]
var plot_upper_bound : Optional[float]
var plot_legend_name : Optional[str]
Methods
def update(self, probs: torch.Tensor, targets: torch.Tensor) ‑> None
-
Parameters
probs
:torch.Tensor
- probability tensor of shape (…, n_classes)
targets
:torch.Tensor
- integer tensor of shape (…)
def compute(self) ‑> torch.Tensor
-
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in distributed backend.
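Example (a minimal sketch):
import torch
from laplace.utils import RunningNLLMetric

metric = RunningNLLMetric()
probs = torch.softmax(torch.randn(4, 3), dim=-1)   # (batch, n_classes) probabilities
targets = torch.tensor([0, 1, 2, 1])
metric.update(probs, targets)
print(metric.compute())   # NLL accumulated over all batches seen so far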
class SubsetOfWeights (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var ALL
var LAST_LAYER
var SUBNETWORK
class HessianStructure (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var FULL
var KRON
var DIAG
var LOWRANK
var GP
class Likelihood (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var REGRESSION
var CLASSIFICATION
var REWARD_MODELING
class PredType (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var GLM
var NN
var GP
class LinkApprox (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var MC
var PROBIT
var BRIDGE
var BRIDGE_NORM
class TuningMethod (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var MARGLIK
var GRIDSEARCH
class PriorStructure (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
An enumeration.
Ancestors
- builtins.str
- enum.Enum
Class variables
var SCALAR
var DIAG
var LAYERWISE