Module laplace.utils.matrix
Classes
class Kron (kfacs)
-
Kronecker factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors Q and H, and a single element is just a full block Hessian approximation.
Parameters
kfacs
:list[Tuple]
- each element in the list is a Tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in case of bias, for example)
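To make the representation concrete, the plain-PyTorch sketch below builds a kfacs-style list for one linear layer. The weight gets a (Q, H) tuple and the bias a single full matrix. The sketch then compares the number of stored entries with the size of the dense block they stand for. The layer sizes and the factor ordering (first factor over the weight's rows, second over its columns) are illustrative assumptions, not taken from the library.

```python
import torch

def factors(block):
    """Normalize a kfacs element to a tuple of matrices."""
    return tuple(block) if isinstance(block, (tuple, list)) else (block,)

def count_entries(kfacs):
    """Number of entries actually stored in a kfacs-style list."""
    return sum(m.numel() for block in kfacs for m in factors(block))

def count_dense_entries(kfacs):
    """Number of entries in the dense block-diagonal matrix the list represents."""
    total = 0
    for block in kfacs:
        # A (Q, H) tuple stands for Q ⊗ H, whose side length is the product of the factor sizes.
        side = 1
        for m in factors(block):
            side *= m.shape[0]
        total += side ** 2
    return total

# Hypothetical layer: weight of shape (out_features, in_features) plus a bias.
out_features, in_features = 64, 128
kfacs = [
    (torch.eye(out_features), torch.eye(in_features)),  # Kronecker factors Q, H for the weight
    torch.eye(out_features),                            # full block for the bias
]

print(count_entries(kfacs))        # 64*64 + 128*128 + 64*64 = 24576
print(count_dense_entries(kfacs))  # (64*128)**2 + 64**2 = 67112960
```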
Static methods
def init_from_model(model, device)
-
Initialize Kronecker factors based on a model's architecture.
Parameters
model
:nn.Module or iterable of parameters, e.g. model.parameters()
device
:torch.device
Returns
kron
:Kron
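The sketch below shows one way to initialize factors from an architecture. It iterates over the parameters and allocates zero matrices whose sizes follow from each parameter's shape. The helper init_kfacs is hypothetical, and so are its shape rules:

- a 1-D bias gets one P x P block;
- a 2-D to 4-D weight gets one factor over its first dimension and one over the remaining dimensions flattened.

```python
import math
import torch
import torch.nn as nn

def init_kfacs(params, device=torch.device("cpu")):
    """Hypothetical helper: allocate zero Kronecker factors for each parameter tensor."""
    kfacs = []
    for p in params:
        if p.ndim == 1:
            # Bias: a single full P x P block.
            P = p.shape[0]
            kfacs.append(torch.zeros(P, P, device=device))
        elif 2 <= p.ndim <= 4:
            # Linear or conv weight: first dimension vs. all remaining dimensions flattened.
            rows, cols = p.shape[0], math.prod(p.shape[1:])
            kfacs.append((torch.zeros(rows, rows, device=device),
                          torch.zeros(cols, cols, device=device)))
        else:
            raise ValueError(f"unsupported parameter shape {tuple(p.shape)}")
    return kfacs

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(), nn.Linear(8 * 6 * 6, 10))
kfacs = init_kfacs(model.parameters())
for block in kfacs:
    print([tuple(m.shape) for m in block] if isinstance(block, tuple) else tuple(block.shape))
```

Passing model.parameters() rather than the module itself keeps the helper agnostic to how the layers are nested.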
Methods
def decompose(self, damping=False)
-
Eigendecompose the Kronecker factors and turn them into KronDecomposed.
Parameters
damping
:bool
- use damping
Returns
kron_decomposed
:KronDecomposed
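Eigendecomposing the factors separately is what makes later operations cheap. The eigenvalues of Q ⊗ H are all pairwise products of the factors' eigenvalues, and its eigenvectors are Kronecker products of the factors' eigenvectors. The plain-PyTorch sketch below checks both facts with torch.linalg.eigh on random symmetric positive definite factors. It illustrates the math and does not reproduce KronDecomposed.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

Q, H = random_spd(3), random_spd(4)

# Decompose each small factor instead of the 12 x 12 product.
lam_Q, V_Q = torch.linalg.eigh(Q)
lam_H, V_H = torch.linalg.eigh(H)

# Eigenvalues of Q ⊗ H are all products lam_Q[i] * lam_H[j] (row-major order).
kron_eigs = torch.outer(lam_Q, lam_H).flatten()
dense_eigs = torch.linalg.eigvalsh(torch.kron(Q, H))  # ascending order
print(torch.allclose(kron_eigs.sort().values, dense_eigs))

# Eigenvectors of Q ⊗ H are the columns of V_Q ⊗ V_H.
V = torch.kron(V_Q, V_H)
print(torch.allclose(V @ torch.diag(kron_eigs) @ V.T, torch.kron(Q, H)))
```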
def bmm(self, W: torch.Tensor, exponent: float = 1) -> torch.Tensor
-
Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive distribution or a regularization term based on Kronecker factors, as in continual learning.
Parameters
W
:torch.Tensor
- matrix of shape (batch, classes, params)
exponent
:float
, default=1
- can only be 1 for Kron; other exponent values of the Kronecker factors require KronDecomposed.
Returns
SW
:torch.Tensor
- result of shape (batch, classes, params)
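Multiplying by a Kronecker factorization without densifying it relies on the identity (Q ⊗ H) vec(X) = vec(Q X Hᵀ), where vec is row-major flattening (what torch.reshape produces). The sketch below applies this identity to every row of a W of shape (batch, classes, params) for a single (Q, H) block, then compares the result with the dense product. The name kron_bmm is hypothetical. With several blocks, the params dimension would be split per block first.

```python
import torch

def kron_bmm(Q, H, W):
    """Compute (Q ⊗ H) @ w for every row w of W, where W has shape (batch, classes, params)."""
    batch, classes, params = W.shape
    p, q = Q.shape[0], H.shape[0]
    assert params == p * q, "params must equal the size of Q ⊗ H"
    # Row-major reshape: each row w equals vec(X) for some X of shape (p, q).
    X = W.reshape(batch, classes, p, q)
    # (Q ⊗ H) vec(X) = vec(Q X H^T); matmul broadcasts over (batch, classes).
    SX = Q @ X @ H.T
    return SX.reshape(batch, classes, params)

torch.manual_seed(0)
Q = torch.randn(3, 3, dtype=torch.float64)
H = torch.randn(5, 5, dtype=torch.float64)
W = torch.randn(2, 4, 15, dtype=torch.float64)

SW = kron_bmm(Q, H, W)
dense = torch.kron(Q, H)
print(torch.allclose(SW, W @ dense.T))  # each row w is mapped to dense @ w
```

The cost per row is O(p²q + pq²) instead of O(p²q²) for the dense product.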
def logdet(self) -> torch.Tensor
-
Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.
Returns
logdet
:torch.Tensor
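The log determinant never needs the dense Kronecker product. If Q is p x p and H is q x q, then log det(Q ⊗ H) = q·log det Q + p·log det H. For the block-diagonal approximation, the per-block values simply add up. A quick numerical check in plain PyTorch follows; it is not the library's code.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

Q, H = random_spd(3), random_spd(5)  # one Kronecker-factored block (15 x 15)
B = random_spd(4)                    # one full block, e.g. for a bias
p, q = Q.shape[0], H.shape[0]

# Cheap: only the small factors and the full block are touched.
logdet_total = q * torch.logdet(Q) + p * torch.logdet(H) + torch.logdet(B)

# Expensive reference on the dense 19 x 19 block-diagonal matrix.
dense = torch.block_diag(torch.kron(Q, H), B)
print(torch.allclose(logdet_total, torch.logdet(dense)))
```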
def diag(self) -> torch.Tensor
-
Extract the diagonal of the entire Kronecker factorization.
Returns
diag
:torch.Tensor
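The diagonal also factorizes: diag(Q ⊗ H) is the Kronecker product of the two diagonals. The diagonal of the whole approximation is these vectors concatenated across blocks. The sketch below uses a hypothetical helper, kfacs_diag, in plain PyTorch.

```python
import torch

def kfacs_diag(kfacs):
    """Diagonal of the block-diagonal matrix represented by a kfacs-style list."""
    parts = []
    for block in kfacs:
        if isinstance(block, (tuple, list)):
            Q, H = block
            # diag(Q ⊗ H)[i * q + j] = Q[i, i] * H[j, j]
            parts.append(torch.kron(torch.diagonal(Q), torch.diagonal(H)))
        else:
            parts.append(torch.diagonal(block))
    return torch.cat(parts)

torch.manual_seed(0)
Q, H, B = torch.randn(3, 3), torch.randn(4, 4), torch.randn(3, 3)
kfacs = [(Q, H), B]

dense = torch.block_diag(torch.kron(Q, H), B)
print(torch.allclose(kfacs_diag(kfacs), torch.diagonal(dense)))
```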
def to_matrix(self) -> torch.Tensor
-
Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it allocates large amounts of memory for big architectures.
Returns
block_diag
:torch.Tensor
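Densifying amounts to one Kronecker product per tuple, followed by block-diagonal assembly. The sketch below does this with torch.kron and torch.block_diag on toy factors; kfacs_to_matrix is a hypothetical name. The memory warning is easy to quantify: a single 512 x 1024 weight matrix would already yield a dense block with (512·1024)² ≈ 2.7·10¹¹ entries.

```python
import torch

def kfacs_to_matrix(kfacs):
    """Dense block-diagonal matrix for a kfacs-style list (testing only)."""
    blocks = [torch.kron(*block) if isinstance(block, (tuple, list)) else block
              for block in kfacs]
    return torch.block_diag(*blocks)

Q, H = torch.eye(2), torch.full((3, 3), 2.0)  # Kronecker-factored block of size 6
B = torch.ones(2, 2)                           # full block of size 2
M = kfacs_to_matrix([(Q, H), B])
print(M.shape)  # torch.Size([8, 8])
```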
class KronDecomposed (eigenvectors, eigenvalues, deltas=None, damping=False)
-
Decomposed Kronecker factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layerwise scalars, but we can no longer add another Kron or KronDecomposed.
Parameters
eigenvectors
:list[Tuple[torch.Tensor]]
- eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues
:list[Tuple[torch.Tensor]]
- eigenvalues corresponding to matrices in a corresponding Kron
deltas
:torch.Tensor
- addend for each group of Kronecker factors representing, for example, a prior precision
damping
:bool
, default=False
- use the damping approximation, which mixes the prior and Kron partially multiplicatively
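Adding a scalar stays cheap after decomposition. If Q = V_Q diag(λ) V_Qᵀ and H = V_H diag(μ) V_Hᵀ, then Q ⊗ H + δI keeps the eigenvectors V_Q ⊗ V_H and has eigenvalues λ_i μ_j + δ. A general Kronecker-factored addend would not share this eigenbasis, which is consistent with only scalar addends being supported. The plain-PyTorch check below uses δ as a stand-in for one entry of deltas (for example, a prior precision). It covers only the additive case, not the damping variant.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

Q, H = random_spd(3), random_spd(4)
delta = 0.5  # scalar addend for this group of factors, e.g. a prior precision

lam, V_Q = torch.linalg.eigh(Q)
mu, V_H = torch.linalg.eigh(H)

# Shifted eigenvalues with unchanged eigenvectors describe Q ⊗ H + delta * I.
eigs = torch.outer(lam, mu).flatten() + delta
V = torch.kron(V_Q, V_H)

eye = torch.eye(12, dtype=torch.float64)
dense = torch.kron(Q, H) + delta * eye
print(torch.allclose(V @ torch.diag(eigs) @ V.T, dense))

# The inverse reuses the same decomposition; only the eigenvalues are inverted.
inv = V @ torch.diag(1.0 / eigs) @ V.T
print(torch.allclose(inv @ dense, eye))
```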
Methods
def detach(self)
def logdet(self) -> torch.Tensor
-
Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are added.
Returns
logdet
:torch.Tensor
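Once the eigenvalues are known, the log determinant including an additive delta is a sum of logs: log det(Q ⊗ H + δI) = Σ_{i,j} log(λ_i μ_j + δ). The plain-PyTorch sketch below checks this for one block against the dense computation.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

Q, H = random_spd(3), random_spd(4)
delta = 0.5  # e.g. a prior precision for this block

lam = torch.linalg.eigvalsh(Q)
mu = torch.linalg.eigvalsh(H)
# Sum of logs of the shifted pairwise eigenvalue products.
logdet_decomposed = torch.log(torch.outer(lam, mu) + delta).sum()

dense = torch.kron(Q, H) + delta * torch.eye(12, dtype=torch.float64)
print(torch.allclose(logdet_decomposed, torch.logdet(dense)))
```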
def inv_square_form(self, W: torch.Tensor) -> torch.Tensor
def bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor
-
Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive distribution or a regularization loss. Compared to Kron.bmm(), a prior can be added here in the form of deltas, and the exponent can be other than 1. Computes H^{exponent} W.
Parameters
W
:torch.Tensor
- matrix of shape (batch, classes, params)
exponent
:float
, default=-1
- exponent of the Kronecker factorization
Returns
SW
:torch.Tensor
- result of shape (batch, classes, params)
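Computing H^{exponent} W combines the reshape trick with the eigendecomposition. Each row of W is rotated into the shared eigenbasis, scaled by (λ_i μ_j + δ)^exponent, and rotated back. The sketch below handles one block in plain PyTorch; decomposed_bmm is a hypothetical name. It uses the default exponent -1 and compares the result with a dense linear solve.

```python
import torch

def decomposed_bmm(V_Q, lam, V_H, mu, W, delta=0.0, exponent=-1.0):
    """Compute (Q ⊗ H + delta*I)^exponent @ w for each row w of W (batch, classes, params)."""
    batch, classes, params = W.shape
    p, q = V_Q.shape[0], V_H.shape[0]
    X = W.reshape(batch, classes, p, q)
    # Rotate into the eigenbasis: (V_Q ⊗ V_H)^T vec(X) = vec(V_Q^T X V_H).
    Y = V_Q.T @ X @ V_H
    # Scale coordinate (i, j) by its shifted eigenvalue raised to the exponent.
    Y = Y * (torch.outer(lam, mu) + delta) ** exponent
    # Rotate back: (V_Q ⊗ V_H) vec(Y) = vec(V_Q Y V_H^T).
    return (V_Q @ Y @ V_H.T).reshape(batch, classes, params)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

torch.manual_seed(0)
Q, H = random_spd(3), random_spd(4)
delta = 0.5
lam, V_Q = torch.linalg.eigh(Q)
mu, V_H = torch.linalg.eigh(H)
W = torch.randn(2, 5, 12, dtype=torch.float64)

SW = decomposed_bmm(V_Q, lam, V_H, mu, W, delta=delta, exponent=-1.0)

# Reference: solve (Q ⊗ H + delta*I) s = w for every row w.
dense = torch.kron(Q, H) + delta * torch.eye(12, dtype=torch.float64)
ref = torch.linalg.solve(dense, W.reshape(-1, 12).T).T.reshape(2, 5, 12)
print(torch.allclose(SW, ref))
```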
def diag(self, exponent: float = 1) -> torch.Tensor
-
Extract the diagonal of the entire decomposed Kronecker factorization.
Parameters
exponent
:float
, default=1
- exponent of the Kronecker factorization
Returns
diag
:torch.Tensor
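The diagonal of a power can be read off the decomposition without forming the dense matrix. For entry (a, b), the diagonal is Σ_{i,j} V_Q[a,i]² V_H[b,j]² (λ_i μ_j)^exponent. In matrix form, that is (V_Q∘V_Q) S (V_H∘V_H)ᵀ, where S holds the powered eigenvalue products and ∘ is the elementwise product. The plain-PyTorch sketch below omits deltas for brevity and checks exponent -1 against a dense inverse.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

Q, H = random_spd(3), random_spd(4)
lam, V_Q = torch.linalg.eigh(Q)
mu, V_H = torch.linalg.eigh(H)

exponent = -1.0
S = torch.outer(lam, mu) ** exponent  # powered eigenvalue products, shape (3, 4)
# Squared eigenvector entries weight each powered eigenvalue; flatten row-major.
diag = ((V_Q ** 2) @ S @ (V_H ** 2).T).flatten()

dense_inv = torch.linalg.inv(torch.kron(Q, H))
print(torch.allclose(diag, torch.diagonal(dense_inv)))
```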
def to_matrix(self, exponent: float = 1) -> torch.Tensor
-
Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it allocates large amounts of memory for big architectures.
Parameters
exponent
:float
, default=1
- exponent of the Kronecker factorization
Returns
block_diag
:torch.Tensor
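For testing, a powered dense matrix can be rebuilt from the decomposition as (V_Q ⊗ V_H) diag((λ ⊗ μ)^exponent) (V_Q ⊗ V_H)ᵀ. The plain-PyTorch sketch below runs three checks: exponent 1 recovers Q ⊗ H, exponents 1 and -1 give mutually inverse matrices, and exponent 0.5 gives a matrix square root. The helper decomposed_to_matrix is hypothetical, and like to_matrix it allocates the full dense block.

```python
import torch

torch.manual_seed(0)

def random_spd(n):
    """Random symmetric positive definite matrix (stand-in for a Kronecker factor)."""
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.T + n * torch.eye(n, dtype=torch.float64)

def decomposed_to_matrix(V_Q, lam, V_H, mu, exponent=1.0):
    """Dense (Q ⊗ H)^exponent from the factors' eigendecompositions (testing only)."""
    V = torch.kron(V_Q, V_H)
    return V @ torch.diag(torch.outer(lam, mu).flatten() ** exponent) @ V.T

Q, H = random_spd(2), random_spd(3)
lam, V_Q = torch.linalg.eigh(Q)
mu, V_H = torch.linalg.eigh(H)

M = decomposed_to_matrix(V_Q, lam, V_H, mu, 1.0)
M_inv = decomposed_to_matrix(V_Q, lam, V_H, mu, -1.0)
M_sqrt = decomposed_to_matrix(V_Q, lam, V_H, mu, 0.5)
eye = torch.eye(6, dtype=torch.float64)

print(torch.allclose(M, torch.kron(Q, H)))
print(torch.allclose(M @ M_inv, eye))
print(torch.allclose(M_sqrt @ M_sqrt, M))
```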