Module laplace.curvature.curvature

Classes

class CurvatureInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface to access curvature for a model and its corresponding likelihood. Any concrete curvature interface must inherit from this base class and implement the required methods jacobians, full, kron, and diag. The interface may be extended in the future to support other curvature structures, for example a block-diagonal one.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

Attributes

lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
 
factor : float
conversion factor between torch losses and base likelihoods. For example, a factor of \frac{1}{2} converts MSELoss into the negative log-likelihood of \mathcal{N}(f, 1), up to an additive constant.
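The role of factor can be checked numerically. The following is a minimal, dependency-free sketch, assuming a sum-reduced squared-error loss; the helper names mse_sum and gaussian_nll are illustrative and not part of the library.

```python
import math

def mse_sum(f, y):
    """Sum-reduced squared error between predictions f and targets y."""
    return sum((fi - yi) ** 2 for fi, yi in zip(f, y))

def gaussian_nll(f, y):
    """Negative log-density of N(y; f, 1), summed over all entries."""
    log_norm = 0.5 * math.log(2 * math.pi)
    return sum(0.5 * (yi - fi) ** 2 + log_norm for fi, yi in zip(f, y))

f = [0.5, -1.0, 2.0]   # model outputs (illustrative values)
y = [1.0, 0.0, 2.5]    # regression targets (illustrative values)

factor = 0.5                                   # the conversion factor described above
const = 0.5 * len(y) * math.log(2 * math.pi)   # parameter-independent normaliser

# factor * squared error equals the Gaussian NLL up to the additive constant
assert math.isclose(factor * mse_sum(f, y), gaussian_nll(f, y) - const)
```

The additive constant does not depend on the parameters, so it has no effect on gradients or curvature.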

Subclasses

Methods

def jacobians(self, x, enable_backprop=False)

Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta, via torch.func.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default = False
whether to enable backprop through the Js and f w.r.t. x

Returns

Js : torch.Tensor
Jacobians (batch, parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
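To make the documented return shapes concrete, here is a small sketch for a toy linear model. It uses central finite differences in plain Python instead of torch.func, so it only illustrates the (batch, parameters, outputs) layout and is not the library's implementation. The function names and the parameter ordering are assumptions made for this example.

```python
def linear(theta, x, n_out):
    """Toy model f(x; theta) = W x + b, with theta = [W in row-major order, then b]."""
    d = len(x)
    W = [theta[c * d:(c + 1) * d] for c in range(n_out)]
    b = theta[n_out * d:]
    return [sum(w * xi for w, xi in zip(W[c], x)) + b[c] for c in range(n_out)]

def jacobians(theta, xs, n_out, eps=1e-6):
    """Finite-difference Jacobians laid out as (batch, parameters, outputs)."""
    Js, fs = [], []
    for x in xs:
        J = []
        for p in range(len(theta)):
            up = theta[:p] + [theta[p] + eps] + theta[p + 1:]
            dn = theta[:p] + [theta[p] - eps] + theta[p + 1:]
            f_up, f_dn = linear(up, x, n_out), linear(dn, x, n_out)
            J.append([(hi - lo) / (2 * eps) for hi, lo in zip(f_up, f_dn)])
        Js.append(J)
        fs.append(linear(theta, x, n_out))
    return Js, fs

theta = [0.1, -0.2, 0.3, 0.4, 0.0, 1.0]      # W is 2x2, b has 2 entries, so P = 6
xs = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # batch of 3 inputs

Js, f = jacobians(theta, xs, n_out=2)
shape_Js = (len(Js), len(Js[0]), len(Js[0][0]))   # (batch, parameters, outputs)
shape_f = (len(f), len(f[0]))                     # (batch, outputs)
```

For this linear model, the derivative of output c with respect to W[c][d] is x[d], and is zero for rows of W belonging to other outputs, which gives a quick sanity check on individual entries of Js.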
def functorch_jacobians(self, x, enable_backprop=False)

Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default = False
whether to enable backprop through the Js and f w.r.t. x

Returns

Js : torch.Tensor
Jacobians (batch, parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
def last_layer_jacobians(self, x, enable_backprop=False)

Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last}) only at current last-layer parameter \theta_{\textrm{last}}.

Parameters

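x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default=False
whether to enable backprop through the Js and f w.r.t. x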

Returns

Js : torch.Tensor
Jacobians (batch, last-layer-parameters, outputs)
f : torch.Tensor
output function (batch, outputs)
def gradients(self, x, y)

Compute batch gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
labels (batch, label_shape)

Returns

Gs : torch.Tensor
gradients (batch, parameters)
loss : torch.Tensor
 
def full(self, x, y, **kwargs)

Compute a dense curvature (approximation) in the form of a P \times P matrix H with respect to parameters \theta \in \mathbb{R}^P.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : torch.Tensor
Hessian approximation (parameters, parameters)
def kron(self, x, y, **kwargs)

Compute a Kronecker-factored curvature approximation (such as KFAC). For each Module in the neural network that permits such curvature, the approximation to H takes the form of two Kronecker factors Q and H, i.e., H \approx Q \otimes H. The symbol H is overloaded here: on the left it denotes the module's curvature block, on the right the output-side factor. Q is a square matrix over the module's input dimension, of size p_{in} \times p_{in}, and the factor H is square over the output dimension, of size p_{out} \times p_{out}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : Kron
Kronecker factored Hessian approximation.
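The storage benefit of the Kronecker structure can be seen in a small example. This is a plain-Python sketch with made-up factor values. The name H_out for the output-side factor and the helper kron_product are assumptions for illustration and do not reflect the Kron class.

```python
def kron_product(A, B):
    """Dense Kronecker product A (x) B of two matrices given as nested lists."""
    return [[a * b for a in row_a for b in row_b] for row_a in A for row_b in B]

# Input-side factor Q (p_in x p_in) and output-side factor H_out (p_out x p_out)
Q = [[2.0, 0.5, 0.0],
     [0.5, 1.0, 0.2],
     [0.0, 0.2, 3.0]]
H_out = [[1.0, 0.1],
         [0.1, 4.0]]

dense = kron_product(Q, H_out)   # (p_in * p_out) x (p_in * p_out) block for one module

p_in, p_out = len(Q), len(H_out)
n_dense = (p_in * p_out) ** 2         # entries needed for the dense block
n_factored = p_in ** 2 + p_out ** 2   # entries needed when only the factors are kept
```

Keeping only the two factors needs p_in^2 + p_out^2 numbers instead of (p_in p_out)^2, a gap that widens quickly for wide layers.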
def diag(self, x, y, **kwargs)

Compute a diagonal approximation to the Hessian H, represented as a vector with the same dimensionality as the parameters \theta.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : torch.Tensor
vector representing the diagonal of H
class GGNInterface (model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False, num_samples=1)

Generalized Gauss-Newton or Fisher Curvature Interface. The GGN is equal to the Fisher information for the available likelihoods. Beyond what CurvatureInterface requires, subclasses must provide methods for computing Jacobians.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over
stochastic : bool, default=False
if True, use the stochastic (sampled) Fisher; otherwise use the GGN
num_samples : int, default=1
Number of samples used to approximate the stochastic Fisher

Ancestors

Subclasses

Methods

def full(self, x, y, **kwargs)

Compute the full GGN P \times P matrix as Hessian approximation H_{ggn} with respect to parameters \theta \in \mathbb{R}^P. In the last-layer case, this reduces to the last-layer parameters \theta_{last}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H : torch.Tensor
GGN (parameters, parameters)
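For reference, the standard textbook form of the GGN can be written as below. This is a sketch of the usual definition, not a statement of how the library computes it. The batch size N, the output dimension C, and the symbol \Lambda_n for the negative log-likelihood Hessian with respect to the model output are notation introduced here.

```latex
H_{ggn} = \sum_{n=1}^{N} J_n \, \Lambda_n \, J_n^\top,
\qquad
J_n = \nabla_\theta f(x_n;\theta) \in \mathbb{R}^{P \times C},
\qquad
\Lambda_n = -\nabla_f^2 \log p(y_n \mid f)\,\Big|_{f = f(x_n;\theta)} \in \mathbb{R}^{C \times C}

% Regression with likelihood N(f, 1):      \Lambda_n = I_C
% Classification with softmax output p_n:  \Lambda_n = \mathrm{diag}(p_n) - p_n p_n^\top
```

With J_n laid out as (parameters, outputs), matching the Jacobian shape documented above, each summand is a P \times P matrix, so the sum has the stated (parameters, parameters) shape.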

Inherited members

class EFInterface (model, likelihood, last_layer=False, subnetwork_indices=None)

Interface for Empirical Fisher as Hessian approximation. Beyond what CurvatureInterface requires, subclasses must provide methods for computing gradients.

Parameters

model : torch.nn.Module or FeatureExtractor
torch model (neural network)
likelihood : {'classification', 'regression'}
 
last_layer : bool, default=False
only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None
indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

Attributes

lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
 
factor : float
conversion factor between torch losses and base likelihoods. For example, a factor of \frac{1}{2} converts MSELoss into the negative log-likelihood of \mathcal{N}(f, 1), up to an additive constant.

Ancestors

Subclasses

Methods

def full(self, x, y, **kwargs)

Compute the full EF P \times P matrix as Hessian approximation H_{ef} with respect to parameters \theta \in \mathbb{R}^P. In the last-layer case, this reduces to the last-layer parameters \theta_{last}.

Parameters

x : torch.Tensor
input data (batch, input_shape)
y : torch.Tensor
labels (batch, label_shape)

Returns

loss : torch.Tensor
 
H_ef : torch.Tensor
EF (parameters, parameters)
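The empirical Fisher is commonly defined as the sum of outer products of per-example gradients. The sketch below illustrates that definition in plain Python, assuming a batch of gradients Gs with the (batch, parameters) layout returned by gradients. The gradient values are made up, the helper name empirical_fisher is hypothetical, and any scaling the library applies is not reproduced.

```python
def empirical_fisher(Gs):
    """Sum of outer products g g^T over per-example gradients Gs (batch, parameters)."""
    P = len(Gs[0])
    H = [[0.0] * P for _ in range(P)]
    for g in Gs:
        for i in range(P):
            for j in range(P):
                H[i][j] += g[i] * g[j]
    return H

Gs = [[1.0, 0.0, 2.0],
      [0.5, -1.0, 0.0]]   # two examples, three parameters (illustrative values)

H_ef = empirical_fisher(Gs)                        # (parameters, parameters)
diag_ef = [H_ef[i][i] for i in range(len(H_ef))]   # diagonal, as a vector
```

The result is symmetric and positive semi-definite by construction, and its diagonal is simply the sum of squared per-example gradients.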

Inherited members