Module laplace.utils.utils

Functions

def get_nll(out_dist, targets)
def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float
def parameters_per_layer(model)

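Get the number of parameters per layer. Each tensor yielded by model.parameters() is a separate entry, so a layer's weight and bias are listed separately.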

Parameters

model : torch.nn.Module
 

Returns

params_per_layer : list[int]
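A minimal sketch of the counting, assuming (as the description states) one entry per tensor yielded by model.parameters(). The helper name count_params_per_tensor is hypothetical:

```python
from typing import List

import torch.nn as nn


def count_params_per_tensor(model: nn.Module) -> List[int]:
    # One entry per parameter tensor, in model.parameters() order
    return [p.numel() for p in model.parameters()]


# Linear(3, 2): weight 2x3 = 6, bias 2; Linear(2, 1): weight 1x2 = 2, bias 1
model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
print(count_params_per_tensor(model))  # [6, 2, 2, 1]
```

The ReLU has no parameters and therefore contributes no entry.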
 
def invsqrt_precision(M)

Compute M^{-0.5} as a lower-triangular matrix, i.e. a factor L with L L^T = M^{-1}.

Parameters

M : torch.Tensor
 

Returns

M_invsqrt : torch.Tensor
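One way to obtain a lower-triangular L with L L^T = M^{-1} without forming M^{-1} explicitly is to take the Cholesky factor of M with its row and column order reversed, then invert a triangular matrix. The sketch below (hypothetical name precision_to_scale_tril) assumes M is symmetric positive definite; it illustrates the technique and is not necessarily the library's exact code:

```python
import torch


def precision_to_scale_tril(P: torch.Tensor) -> torch.Tensor:
    """Lower-triangular L with L @ L.T == inverse(P), for SPD P."""
    # Cholesky of the index-reversed precision: J P J = Lf Lf^T (Lf lower)
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    # Reversing back gives an upper-triangular U = J Lf J with P = U U^T;
    # its transpose U^T is lower-triangular.
    Ut = torch.flip(Lf, (-2, -1)).transpose(-2, -1)
    # L = (U^T)^{-1} is lower-triangular and L L^T = U^{-T} U^{-1} = P^{-1}
    eye = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    return torch.linalg.solve_triangular(Ut, eye, upper=False)


P = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
L = precision_to_scale_tril(P)
print(torch.allclose(L @ L.T, torch.linalg.inv(P)))  # True
print(torch.allclose(L, torch.tril(L)))  # True
```

A factor of this form can be passed as scale_tril to torch.distributions.MultivariateNormal, so that samples have covariance M^{-1}.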
 
def kron(t1, t2)

Computes the Kronecker product between two tensors.

Parameters

t1 : torch.Tensor
 
t2 : torch.Tensor
 

Returns

kron_product : torch.Tensor
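For 2-D inputs, the Kronecker product can be written with broadcasting and a reshape. The sketch below (hypothetical name kron2d) only handles matrices and is checked against PyTorch's built-in torch.kron:

```python
import torch


def kron2d(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    m, n = t1.shape
    p, q = t2.shape
    # outer[i, k, j, l] = t1[i, j] * t2[k, l], shape (m, p, n, q)
    outer = t1[:, None, :, None] * t2[None, :, None, :]
    # Flatten to row index i*p + k and column index j*q + l
    return outer.reshape(m * p, n * q)


t1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
t2 = torch.eye(2)
print(kron2d(t1, t2).tolist())
# [[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 2.0], [3.0, 0.0, 4.0, 0.0], [0.0, 3.0, 0.0, 4.0]]
print(torch.equal(kron2d(t1, t2), torch.kron(t1, t2)))  # True
```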
 
def diagonal_add_scalar(X, value)

Add the scalar value to the diagonal of X.

Parameters

X : torch.Tensor
 
value : torch.Tensor or float
 

Returns

X_add_scalar : torch.Tensor
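A minimal out-of-place sketch using a diagonal view (hypothetical name add_scalar_to_diagonal). Whether the library's version copies X or modifies it in place is an assumption here; this sketch copies. value may be a Python float or a 0-dimensional tensor:

```python
import torch


def add_scalar_to_diagonal(X: torch.Tensor, value) -> torch.Tensor:
    # Work on a copy so the caller's X is left unchanged
    out = X.clone()
    # diagonal() returns a view into out, so the in-place add writes through
    out.diagonal().add_(value)
    return out


X = torch.zeros(3, 3)
print(add_scalar_to_diagonal(X, 2.0).tolist())
# [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
print(X.sum().item())  # 0.0
```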
 
def symeig(M)

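Symmetric eigendecomposition that avoids failure cases by adding jitter to the diagonal and removing it again afterwards. Adding cI to M leaves the eigenvectors unchanged and shifts every eigenvalue by c, so subtracting c recovers the eigenvalues of M.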

Parameters

M : torch.Tensor
 

Returns

L : torch.Tensor
eigenvalues
W : torch.Tensor
eigenvectors
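The jitter trick relies on M + cI having the same eigenvectors as M, with every eigenvalue shifted by c. A minimal sketch, assuming a single retry with a fixed jitter c = 1 (the hypothetical jitter argument of the hypothetical symeig_with_jitter); the library's exact retry policy and any post-processing of the eigenvalues are not reproduced here:

```python
import torch


def symeig_with_jitter(M: torch.Tensor, jitter: float = 1.0):
    try:
        L, W = torch.linalg.eigh(M)
    except RuntimeError:  # e.g. the eigensolver failed to converge
        eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        # (M + c I) W = W (diag(L) + c I): same eigenvectors, shifted eigenvalues
        L, W = torch.linalg.eigh(M + jitter * eye)
        L = L - jitter  # undo the shift
    return L, W


M = torch.tensor([[2.0, 1.0], [1.0, 2.0]], dtype=torch.float64)
L, W = symeig_with_jitter(M)
print(L.tolist())  # approximately [1.0, 3.0] (ascending order)
print(torch.allclose(W @ torch.diag(L) @ W.T, M))  # True
```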
def block_diag(blocks)

Compose a block-diagonal matrix from individual blocks.

Parameters

blocks : list[torch.Tensor]
 

Returns

M : torch.Tensor
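A minimal sketch of the composition for 2-D blocks (hypothetical name block_diag_sketch): each block is written where the previous one ended, and everything off the diagonal blocks stays zero. PyTorch's built-in torch.block_diag takes the blocks as separate arguments, so a list must be unpacked:

```python
from typing import List

import torch


def block_diag_sketch(blocks: List[torch.Tensor]) -> torch.Tensor:
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    M = blocks[0].new_zeros(rows, cols)
    r = c = 0
    for b in blocks:
        # Place this block at the current diagonal offset
        M[r:r + b.shape[0], c:c + b.shape[1]] = b
        r += b.shape[0]
        c += b.shape[1]
    return M


blocks = [torch.ones(2, 2), 3 * torch.ones(1, 1)]
print(block_diag_sketch(blocks).tolist())
# [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 3.0]]
print(torch.equal(block_diag_sketch(blocks), torch.block_diag(*blocks)))  # True
```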
 
def expand_prior_precision(prior_prec, model)

Expand the prior precision into a vector with one entry per model parameter.

Parameters

prior_prec : torch.Tensor 1-dimensional
prior precision
model : torch.nn.Module
torch model with parameters that are regularized by prior_prec

Returns

expanded_prior_prec : torch.Tensor
expanded prior precision, a 1-dimensional tensor whose length equals the total number of model parameters
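A minimal sketch of the expansion, assuming three accepted layouts for prior_prec (one shared value, one value per parameter tensor, or one value per parameter) and that entries follow the order of model.parameters(). The function name and the ValueError are hypothetical:

```python
import torch
import torch.nn as nn


def expand_prior_precision_sketch(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor:
    assert prior_prec.ndim == 1
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)
    if len(prior_prec) == 1:  # one precision shared by all parameters
        return prior_prec.repeat(n_params)
    if len(prior_prec) == n_params:  # already one precision per parameter
        return prior_prec
    if len(prior_prec) == len(params):  # one precision per parameter tensor
        return torch.cat([
            d * torch.ones(p.numel(), dtype=prior_prec.dtype, device=prior_prec.device)
            for d, p in zip(prior_prec, params)
        ])
    raise ValueError("prior_prec length does not match the model")


model = nn.Linear(3, 2)  # weight: 6 entries, bias: 2 entries
print(expand_prior_precision_sketch(torch.tensor([1.0, 10.0]), model).tolist())
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 10.0]
```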