Module laplace.curvature

Sub-modules

laplace.curvature.asdl
laplace.curvature.backpack
laplace.curvature.curvature

Classes

class CurvatureInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface to access curvature for a model and corresponding likelihood. A CurvatureInterface must inherit from this base class and implement the necessary functions jacobians, full, kron, and diag. The interface might be extended in the future to account for other curvature structures, for example, a block-diagonal one.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

Attributes

lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
 
factor : float
conversion factor between torch losses and base likelihoods. For example, \frac{1}{2} to get to \mathcal{N}(f, 1) from MSELoss.
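As a quick check of where the \frac{1}{2} comes from, the sketch below compares a summed squared error with the negative log-density of \mathcal{N}(f, 1). The tensor values are made up, and using reduction='sum' is an assumption made for illustration rather than something taken from the library's source.

```python
import math
import torch

# Hypothetical model outputs f and targets y (illustrative values only).
f = torch.tensor([0.5, -1.0, 2.0])
y = torch.tensor([1.0, 0.0, 1.5])

# Summed squared error, as computed by MSELoss with reduction='sum'.
mse = torch.nn.MSELoss(reduction='sum')(f, y)

# Negative log-likelihood under N(f, 1), summed over data points.
nll = -torch.distributions.Normal(f, 1.0).log_prob(y).sum()

# The two differ by the factor 1/2 plus a constant that does not depend on f.
const = 0.5 * math.log(2 * math.pi) * y.numel()
factor = 0.5
print(torch.allclose(factor * mse + const, nll))
```

Because the additive constant does not depend on the parameters, only the multiplicative factor matters for gradients and curvature.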

Subclasses

Methods

def jacobians(self, x)

Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.

Returns

Js : torch.Tensor
Jacobians (batch, parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
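To make the documented shapes concrete, here is a minimal and deliberately slow sketch that builds the Jacobians with plain torch.autograd.grad, one output of one sample at a time. The helper name naive_jacobians is hypothetical and not part of the library; the backends documented on this page use their own, faster routines.

```python
import torch

def naive_jacobians(model, x):
    """Return (Js, f) with Js of shape (batch, parameters, outputs)."""
    params = [p for p in model.parameters() if p.requires_grad]
    f = model(x)  # (batch, outputs)
    batch, outputs = f.shape
    per_sample = []
    for n in range(batch):
        per_output = []
        for k in range(outputs):
            # Gradient of one scalar output w.r.t. all parameters.
            grads = torch.autograd.grad(f[n, k], params, retain_graph=True)
            per_output.append(torch.cat([g.reshape(-1) for g in grads]))
        per_sample.append(torch.stack(per_output, dim=1))  # (parameters, outputs)
    return torch.stack(per_sample), f.detach()

torch.manual_seed(0)
model = torch.nn.Linear(3, 2)   # 2 * 3 weights + 2 biases = 8 parameters
x = torch.randn(4, 3)
Js, f = naive_jacobians(model, x)
print(Js.shape, f.shape)
```

The parameter axis follows the order of model.parameters(), each flattened row-major, which is an assumption of this sketch.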
def last_layer_jacobians(self, x)

Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last}) only at current last-layer parameter \theta_{\textrm{last}}.

Parameters

x : torch.Tensor
 

Returns

Js : torch.Tensor
Jacobians (batch, last-layer-parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
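When the last layer is a torch.nn.Linear, these Jacobians have a simple closed form in terms of the features entering that layer. The sketch below illustrates this under that assumption; the helper linear_last_layer_jacobians is hypothetical, and the library may compute the same quantity differently.

```python
import torch

def linear_last_layer_jacobians(last_layer, phi):
    """Closed-form Jacobians of f = W phi + b with respect to (W, b).

    Returns shape (batch, last-layer-parameters, outputs), with parameters
    ordered as [W.flatten(), b].
    """
    batch, d = phi.shape
    K = last_layer.out_features
    eye = torch.eye(K, dtype=phi.dtype)
    # d f_c / d W[k, j] = phi[j] if k == c else 0
    J_w = torch.einsum('kc,nj->nkjc', eye, phi).reshape(batch, K * d, K)
    # d f_c / d b[k] = 1 if k == c else 0
    J_b = eye.expand(batch, K, K)
    return torch.cat([J_w, J_b], dim=1)

torch.manual_seed(0)
last_layer = torch.nn.Linear(5, 3)
phi = torch.randn(4, 5)  # stand-in for features from the earlier layers
Js = linear_last_layer_jacobians(last_layer, phi)

# f is linear in the last-layer parameters, so contracting Js with them recovers f.
theta_last = torch.cat([last_layer.weight.reshape(-1), last_layer.bias]).detach()
f_rebuilt = torch.einsum('npk,p->nk', Js, theta_last)
print(Js.shape)
```

No backward pass is needed here, which is one reason last-layer curvature is comparatively cheap.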
def gradients(self, x, y)

Compute gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
 

Returns

loss : torch.Tensor
 
Gs : torch.Tensor
gradients (batch, parameters)
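The (batch, parameters) layout means one flattened gradient per data point rather than a single gradient for the whole batch. A minimal sketch with a plain loop is shown below; naive_gradients is a hypothetical helper, and the summed loss reduction is an assumption of this example.

```python
import torch

def naive_gradients(model, lossfunc, x, y):
    """Return (loss, Gs) with Gs of shape (batch, parameters)."""
    params = [p for p in model.parameters() if p.requires_grad]
    Gs, total = [], 0.0
    for xn, yn in zip(x, y):
        # Loss of a single data point, kept as a batch of size one.
        loss_n = lossfunc(model(xn.unsqueeze(0)), yn.unsqueeze(0))
        grads = torch.autograd.grad(loss_n, params)
        Gs.append(torch.cat([g.reshape(-1) for g in grads]))
        total = total + loss_n.detach()
    return total, torch.stack(Gs)

torch.manual_seed(0)
model = torch.nn.Linear(3, 2)
lossfunc = torch.nn.CrossEntropyLoss(reduction='sum')
x, y = torch.randn(4, 3), torch.tensor([0, 1, 1, 0])
loss, Gs = naive_gradients(model, lossfunc, x, y)
print(Gs.shape)
```

Summing the rows of Gs gives the usual gradient of the summed batch loss.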
def full(self, x, y, **kwargs)

Compute a dense curvature (approximation) in the form of a P \times P matrix H with respect to parameters \theta \in \mathbb{R}^P.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : torch.Tensor
Hessian approximation (parameters, parameters)
def kron(self, x, y, **kwargs)

Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to H takes the form of two Kronecker factors Q, H, i.e., H \approx Q \otimes H for each Module in the neural network permitting such curvature. Q is quadratic in the input-dimension of a module p_{in} \times p_{in} and H in the output-dimension p_{out} \times p_{out}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : Kron
Kronecker factored Hessian approximation.
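For intuition, the sketch below checks the Kronecker structure on a single bias-free linear layer and a single data point, where the factorization is exact. It assumes a column-major vectorization of the weight matrix so that the input factor comes first, as in Q \otimes H above; the matrix Lam is a made-up stand-in for the Hessian of the loss with respect to the layer output. Over a batch, a sum of Kronecker products is generally not itself a Kronecker product, which is why the result is an approximation.

```python
import torch

torch.manual_seed(0)
p_in, p_out = 3, 2
W = torch.randn(p_out, p_in)
a = torch.randn(p_in)             # input to the layer
B = torch.randn(p_out, p_out)
Lam = B @ B.T                     # stand-in output-space Hessian (symmetric PSD)

# Jacobian of f = W a with respect to W, shape (p_out, p_out, p_in).
dW = torch.autograd.functional.jacobian(lambda w: w @ a, W)
# Column-major vectorization: parameter index j * p_out + i for W[i, j].
J = dW.permute(0, 2, 1).reshape(p_out, p_in * p_out)

exact = J.T @ Lam @ J             # dense (p_in * p_out) x (p_in * p_out) block
Q = torch.outer(a, a)             # p_in x p_in factor
H = Lam                           # p_out x p_out factor
kron = torch.kron(Q, H)
print(torch.allclose(exact, kron, atol=1e-5))
```

Storing Q and H takes p_in^2 + p_out^2 numbers instead of (p_in p_out)^2 for the dense block.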
def diag(self, x, y, **kwargs)

Compute a diagonal Hessian approximation to H, represented as a vector with the same dimensionality as the parameters \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : torch.Tensor
vector representing the diagonal of H
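As an illustration of what the returned vector contains, the sketch below assumes a GGN-type curvature for regression with unit noise, so the per-sample output Hessian is the identity. Under that assumption the diagonal is a sum of squared Jacobian entries and can be formed without building the dense matrix. The random Js tensor is a stand-in for real Jacobians.

```python
import torch

torch.manual_seed(0)
batch, n_params, n_outputs = 5, 7, 3
Js = torch.randn(batch, n_params, n_outputs)   # stand-in Jacobians

# Dense curvature sum_n J_n J_n^T, built here only for comparison.
H_full = torch.einsum('npk,nqk->pq', Js, Js)

# Its diagonal, computed directly from squared Jacobian entries.
H_diag = (Js ** 2).sum(dim=(0, 2))
print(torch.allclose(H_diag, torch.diagonal(H_full), atol=1e-5))
```

The direct computation needs memory proportional to the number of parameters, not to its square.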
class GGNInterface (model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

Generalized Gauss-Newton or Fisher Curvature Interface. The GGN is equal to the Fisher information for the available likelihoods. In addition to CurvatureInterface, methods for Jacobians are required by subclasses.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over
stochastic : bool, default=False
Fisher if stochastic else GGN

Ancestors

Subclasses

Methods

def full(self, x, y, **kwargs)

Compute the full GGN P \times P matrix as Hessian approximation H_{ggn} with respect to parameters \theta \in \mathbb{R}^P. In the last-layer case, this is reduced to \theta_{last}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H_ggn : torch.Tensor
GGN (parameters, parameters)
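A minimal sketch of how a full GGN can be assembled from Jacobians for the classification likelihood is given below. It assumes a softmax cross-entropy loss, whose Hessian with respect to the logits is diag(p) - p p^T, and uses random stand-in tensors. The helper ggn_full_classification is hypothetical, and any scaling by factor or other library-specific details are omitted.

```python
import torch

def ggn_full_classification(Js, f):
    """GGN sum_n J_n Lam_n J_n^T for softmax cross-entropy.

    Js: (batch, parameters, outputs), f: logits of shape (batch, outputs).
    """
    p = torch.softmax(f, dim=-1)
    # Hessian of the cross-entropy with respect to the logits, per sample.
    Lam = torch.diag_embed(p) - torch.einsum('nk,nc->nkc', p, p)
    return torch.einsum('npk,nkc,nqc->pq', Js, Lam, Js)

torch.manual_seed(0)
Js = torch.randn(4, 6, 3)   # stand-in Jacobians
f = torch.randn(4, 3)       # stand-in logits
H_ggn = ggn_full_classification(Js, f)
print(H_ggn.shape)
```

Each per-sample term is positive semi-definite, so the resulting matrix is as well, unlike the exact Hessian of a neural network loss in general.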

Inherited members

class EFInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface for Empirical Fisher as Hessian approximation. In addition to CurvatureInterface, methods for gradients are required by subclasses.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

Attributes

lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
 
factor : float
conversion factor between torch losses and base likelihoods. For example, \frac{1}{2} to get to \mathcal{N}(f, 1) from MSELoss.

Ancestors

Subclasses

Methods

def full(self, x, y, **kwargs)

Compute the full EF P \times P matrix as Hessian approximation H_{ef} with respect to parameters \theta \in \mathbb{R}^P. In the last-layer case, this is reduced to \theta_{last}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H_ef : torch.Tensor
EF (parameters, parameters)
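The empirical Fisher is commonly defined as the sum of outer products of per-sample loss gradients. The sketch below uses that definition with random stand-in gradients in the (batch, parameters) layout; any additional scaling applied by the library is not shown.

```python
import torch

torch.manual_seed(0)
Gs = torch.randn(5, 7)   # stand-in per-sample gradients (batch, parameters)

# Sum over the batch of g_n g_n^T, written as one matrix product.
H_ef = Gs.T @ Gs

# The same quantity with an explicit loop, for comparison.
H_loop = sum(torch.outer(g, g) for g in Gs)
print(torch.allclose(H_ef, H_loop, atol=1e-5))
```

As a sum of batch-many rank-one terms, this matrix has rank at most the batch size.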

Inherited members

class BackPackInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface for Backpack backend.

Ancestors

Subclasses

Methods

def jacobians(self, x)

Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta using Backpack's BatchGrad per output dimension.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.

Returns

Js : torch.Tensor
Jacobians (batch, parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
def gradients(self, x, y)

Compute gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta using Backpack's BatchGrad.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
 

Returns

loss : torch.Tensor
 
Gs : torch.Tensor
gradients (batch, parameters)

Inherited members

class BackPackGGN (model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

Implementation of the GGNInterface using Backpack.

Ancestors

Inherited members

class BackPackEF (model, likelihood, last_layer=False, subnetwork_indices=None)

Implementation of EFInterface using Backpack.

Ancestors

Inherited members

class AsdlInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface for asdfghjkl backend.

Ancestors

Subclasses

Methods

def jacobians(self, x)

Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta using asdfghjkl's gradient per output dimension.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.

Returns

Js : torch.Tensor
Jacobians (batch, parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
def gradients(self, x, y)

Compute gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta using asdfghjkl's backend.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
 

Returns

loss : torch.Tensor
 
Gs : torch.Tensor
gradients (batch, parameters)

Inherited members

class AsdlGGN (model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

Implementation of the GGNInterface using asdfghjkl.

Ancestors

Inherited members

class AsdlEF (model, likelihood, last_layer=False)

Implementation of the EFInterface using asdfghjkl.

Ancestors

Inherited members

class AsdlHessian (model, likelihood, last_layer=False, low_rank=10)

Interface for asdfghjkl backend.

Ancestors

Methods

def eig_lowrank(self, data_loader)

Inherited members