Module laplace.utils.matrix

Classes

class Kron (kfacs)

Kronecker-factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors Q and H, while a single matrix is a full block-Hessian approximation.

Parameters

kfacs : list[Tuple]
each element in the list is a tuple of two Kronecker factors Q, H, or a single matrix approximating the Hessian (for a bias, for example)
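As a sketch of the expected structure (the layer and factor shapes here are illustrative assumptions, not taken from a real model), the kfacs list for a single linear layer with a bias might look like:

```python
import torch

# Hypothetical factors for a Linear(in_features=3, out_features=2) layer.
# The weight block gets two Kronecker factors Q and H; the bias block,
# being small, gets a single full Hessian matrix.
Q = torch.eye(3)          # illustrative (in_features, in_features) factor
H = torch.eye(2)          # illustrative (out_features, out_features) factor
bias_hess = torch.eye(2)  # full Hessian block for the bias

kfacs = [(Q, H), (bias_hess,)]  # tuple of two factors, then a single-matrix block
```

The single-matrix case is wrapped in a one-element tuple here so every list entry has a uniform container type; that choice is illustrative.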

Static methods

def init_from_model(model, device)

Initialize Kronecker factors based on a model's architecture.

Parameters

model : torch.nn.Module
 
device : torch.device
 

Returns

kron : Kron
 

Methods

def decompose(self, damping=False)

Eigendecompose Kronecker factors and turn into KronDecomposed.

Parameters

damping : bool
use damping

Returns

kron_decomposed : KronDecomposed
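The idea behind front-loading the decomposition can be sketched in plain torch (this illustrates the underlying algebra, not the library's internal code): eigendecomposing each factor once is enough to recover any power of the Kronecker product cheaply, since the eigenvalues of Q ⊗ H are all products of the factors' eigenvalues.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3); Q = A @ A.T + torch.eye(3)  # symmetric PD factor
B = torch.randn(2, 2); H = B @ B.T + torch.eye(2)  # symmetric PD factor

lq, Vq = torch.linalg.eigh(Q)  # eigendecompose each small factor once
lh, Vh = torch.linalg.eigh(H)

# Eigenvalues of Q ⊗ H are the pairwise products lq[i] * lh[j]
kron_eigs = torch.outer(lq, lh).flatten()
dense_eigs = torch.linalg.eigvalsh(torch.kron(Q, H))  # dense check only
```

Sorting and comparing `kron_eigs` against `dense_eigs` confirms the identity without ever decomposing the large dense matrix.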
 
def bmm(self, W: torch.Tensor, exponent: float = 1) ‑> torch.Tensor

Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive or a regularization based on Kronecker factors as in continual learning.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=1
can only be 1 for Kron; other exponent values of the Kronecker factors require KronDecomposed.

Returns

SW : torch.Tensor
result (batch, classes, params)
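The reason such a multiplication never needs the dense Hessian can be sketched in plain torch (an illustration of the Kronecker identity, not the library's internal code): with torch's row-major flattening, kron(Q, H) @ vec(V) equals (Q @ V @ H.T) flattened, so each layer costs only two small matrix products.

```python
import torch

torch.manual_seed(0)
p, q = 4, 3
Q = torch.randn(p, p); Q = Q @ Q.T  # symmetric Kronecker factor
H = torch.randn(q, q); H = H @ H.T  # symmetric Kronecker factor
V = torch.randn(p, q)               # one parameter vector, reshaped to (p, q)

dense = torch.kron(Q, H) @ V.flatten()  # explicit (pq x pq) multiply
factored = (Q @ V @ H.T).flatten()      # two small multiplies, no dense kron
```

The two results agree to floating-point precision, which is why the factored form scales to large layers.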
def logdet(self) ‑> torch.Tensor

Compute the log determinant of each group of Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.

Returns

logdet : torch.Tensor
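This works because of the Kronecker determinant identity: for Q of size p×p and H of size q×q, logdet(Q ⊗ H) = q·logdet(Q) + p·logdet(H). A pure-torch sketch (not the library's internal code):

```python
import torch

torch.manual_seed(0)
p, q = 3, 2
A = torch.randn(p, p); Q = A @ A.T + torch.eye(p)  # PD factor, p x p
B = torch.randn(q, q); H = B @ B.T + torch.eye(q)  # PD factor, q x q

# logdet of the Kronecker product from the small factors alone
factored = q * torch.logdet(Q) + p * torch.logdet(H)
dense = torch.logdet(torch.kron(Q, H))  # dense check only
```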
 
def diag(self) ‑> torch.Tensor

Extract diagonal of the entire Kronecker factorization.

Returns

diag : torch.Tensor
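The diagonal can likewise be read off the factors directly, since diag(Q ⊗ H) is the outer product of the two factor diagonals, flattened row-major (a pure-torch sketch, not the library's internal code):

```python
import torch

torch.manual_seed(0)
Q = torch.randn(3, 3)
H = torch.randn(2, 2)

# Diagonal of Q ⊗ H without forming the dense matrix
factored = torch.outer(torch.diag(Q), torch.diag(H)).flatten()
dense = torch.diag(torch.kron(Q, H))  # dense check only
```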
 
def to_matrix(self) ‑> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Returns

block_diag : torch.Tensor
 
class KronDecomposed (eigenvectors, eigenvalues, deltas=None, damping=False)

Decomposed Kronecker-factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layerwise scalars, but we can no longer add other Kron or KronDecomposed objects.

Parameters

eigenvectors : list[Tuple[torch.Tensor]]
eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues : list[Tuple[torch.Tensor]]
eigenvalues corresponding to matrices in a corresponding Kron
deltas : torch.Tensor
addend for each group of Kronecker factors representing, for example, a prior precision
damping : bool, default=False
use a dampened approximation, mixing the prior and Kron partially multiplicatively
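A pure-torch sketch of what an additive delta does (an illustration of the algebra, not the library's internal code): adding δ·I to Q ⊗ H shifts every Kronecker eigenvalue product by δ, so the log determinant with a prior precision δ becomes a sum over log(λᵢ·μⱼ + δ), computable from the decomposed factors alone.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3); Q = A @ A.T + torch.eye(3)  # PD factor
B = torch.randn(2, 2); H = B @ B.T + torch.eye(2)  # PD factor
delta = 0.5  # e.g. a prior precision (illustrative value)

lq = torch.linalg.eigvalsh(Q)
lh = torch.linalg.eigvalsh(H)

# logdet(Q ⊗ H + delta * I) from factor eigenvalues alone
factored = torch.log(torch.outer(lq, lh) + delta).sum()
dense = torch.logdet(torch.kron(Q, H) + delta * torch.eye(6))  # dense check only
```

This is the exact additive case; the dampened variant instead folds the delta into the factors approximately multiplicatively.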

Methods

def detach(self)
def logdet(self) ‑> torch.Tensor

Compute the log determinant of each group of Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are added.

Returns

logdet : torch.Tensor
 
def inv_square_form(self, W: torch.Tensor) ‑> torch.Tensor
def bmm(self, W: torch.Tensor, exponent: float = -1) ‑> torch.Tensor

Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. Compared to Kron.bmm(), a prior can be added here in the form of deltas and the exponent can differ from 1. Computes H^exponent @ W.

Parameters

W : torch.Tensor
matrix (batch, classes, params)
exponent : float, default=-1
 

Returns

SW : torch.Tensor
result (batch, classes, params)
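A pure-torch sketch of the exponent=-1 case for a single layer (illustrating the algebra, not the library's internal code): with eigendecompositions Q = Vq Λq Vqᵀ and H = Vh Λh Vhᵀ, applying (Q ⊗ H + δI)⁻¹ to a flattened parameter vector needs only small matrix products and an elementwise divide by the shifted eigenvalues.

```python
import torch

torch.manual_seed(0)
p, q = 3, 2
A = torch.randn(p, p); Q = A @ A.T + torch.eye(p)  # PD factor
B = torch.randn(q, q); H = B @ B.T + torch.eye(q)  # PD factor
V = torch.randn(p, q)  # one parameter vector, reshaped to (p, q)
delta = 0.1            # illustrative prior precision

lq, Vq = torch.linalg.eigh(Q)
lh, Vh = torch.linalg.eigh(H)

# (Q ⊗ H + delta * I)^{-1} vec(V) without forming the dense matrix:
Y = Vq.T @ V @ Vh                      # rotate into the Kronecker eigenbasis
Y = Y / (torch.outer(lq, lh) + delta)  # divide by shifted eigenvalues
factored = (Vq @ Y @ Vh.T).flatten()   # rotate back

# Dense check only: solve the full (pq x pq) system
dense = torch.linalg.solve(torch.kron(Q, H) + delta * torch.eye(p * q), V.flatten())
```

Other exponents follow the same pattern by replacing the divide with a power of the shifted eigenvalues.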
def to_matrix(self, exponent: float = 1) ‑> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Returns

block_diag : torch.Tensor