Module laplace.lllaplace

Classes

class LLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Base class for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P, which sums as P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be varied.
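
To make the decomposition concrete, here is a small numerical sketch (illustrative only, not library code) of how a log likelihood Hessian and a diagonal prior precision combine into the posterior precision:

    import torch

    # Illustrative 2-parameter example of P = H_lik + diag(p_0).
    H_lik = torch.tensor([[2.0, 0.5],
                          [0.5, 1.0]])  # stand-in for the summed log likelihood Hessians
    p_0 = torch.full((2,), 0.1)         # scalar prior precision, broadcast to a diagonal
    P = H_lik + torch.diag(p_0)         # posterior precision
    cov = torch.linalg.inv(P)           # covariance of N(theta_MAP, P^{-1})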

Parameters

model : torch.nn.Module or FeatureExtractor
the neural network whose last layer is treated probabilistically; a plain nn.Module is wrapped in a FeatureExtractor internally
likelihood : Likelihood or {'classification', 'regression'}
determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa.
enable_backprop : bool, default=False
whether to enable backprop to the input x through the Laplace predictive. Useful, e.g., for Bayesian optimization.
feature_reduction : FeatureReduction or str, optional, default=None
when the last-layer features are a tensor of dim >= 3, this determines how to reduce them to a dim-2 tensor. E.g., in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, available options are {'pick_first', 'pick_last', 'average'}.
dict_key_x : str, default='input_ids'
The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels'
The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
backend : subclasses of CurvatureInterface
backend for access to curvature/Hessian approximations
last_layer_name : str, default=None
name of the model's last layer; if None, it is determined automatically
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
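
A minimal construction sketch; the toy model and variable names are assumptions, and any LLLaplace subclass accepts the same arguments:

    import torch.nn as nn
    from laplace import DiagLLLaplace

    # Toy classifier; with last_layer_name=None the last layer is detected automatically.
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    la = DiagLLLaplace(
        model,
        likelihood='classification',
        prior_precision=1.0,  # scalar prior precision (= weight decay)
        temperature=1.0,
    )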

Ancestors

BaseLaplace

Subclasses

FullLLLaplace
KronLLLaplace
DiagLLLaplace

Instance variables

var prior_precision_diag : torch.Tensor

Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
the diagonal prior precision as a vector, with one entry per last-layer parameter
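Continuing the hypothetical la from the construction sketch above:

    # After la.fit(...), the last layer and its parameter count are known; a scalar
    # prior precision is then expanded to one entry per last-layer parameter.
    la.prior_precision = 0.5
    print(la.prior_precision_diag)  # tensor([0.5000, 0.5000, ..., 0.5000])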

Methods

def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None

Fit the local Laplace approximation at the parameters of the model.

Parameters

train_loader : torch.utils.data.DataLoader
each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y; train_loader.dataset needs to be set to access N, the size of the dataset
override : bool, default=True
whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False
whether to display a progress bar while iterating over the training batches
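
A fitting sketch with toy data, continuing the hypothetical la; loader sizes are illustrative:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    X, y = torch.randn(100, 10), torch.randint(0, 2, (100,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=16)

    la.fit(train_loader)  # override=True: H, loss, and n_data start fresh
    # la.fit(second_loader, override=False)  # hypothetical second loader: accumulate online
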
def functional_variance_fast(self, X)

Compute the diagonal of the functional (output) variance at the inputs X. Should be overridden if there exists a trick to make this fast!

Parameters

X : torch.Tensor of shape (batch_size, input_dim)
batch of inputs at which to compute the output variance

Returns

f_var_diag : torch.Tensor of shape (batch_size, num_outputs)
Corresponding to the diagonal of the covariance matrix of the outputs
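
An illustrative call, continuing the hypothetical la and toy data from above:

    x = torch.randn(8, 10)
    f_var_diag = la.functional_variance_fast(x)  # shape (8, num_outputs)
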
def state_dict(self) ‑> dict[str, typing.Any]

def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
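
A serialization sketch; the file name is an assumption:

    torch.save(la.state_dict(), 'laplace_state.pt')

    la2 = DiagLLLaplace(model, likelihood='classification')
    la2.load_state_dict(torch.load('laplace_state.pt'))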

Inherited members

class FullLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
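
The same approximation can be obtained through the Laplace factory, the library's main entry point (a sketch, reusing the hypothetical model from above):

    from laplace import Laplace

    la_full = Laplace(model, 'classification',
                      subset_of_weights='last_layer',
                      hessian_structure='full')  # yields a FullLLLaplace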

Ancestors

LLLaplace
FullLaplace
BaseLaplace

Inherited members

class KronLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Last-layer Laplace approximation with Kronecker-factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, while KronDecomposed is used to add the prior, apply a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Damping can be used by initializing with, or later setting, damping=True.
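
A sketch with damping enabled, reusing the hypothetical model and train_loader from above:

    from laplace import KronLLLaplace

    la_kron = KronLLLaplace(model, likelihood='classification', damping=True)
    la_kron.fit(train_loader)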

Ancestors

LLLaplace
KronLaplace
BaseLaplace

Inherited members

class DiagLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Last-layer Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
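
An end-to-end sketch with the diagonal variant; optimize_prior_precision and the GLM predictive with a probit link approximation are standard options in this library, but treat the exact arguments as assumptions:

    from laplace import DiagLLLaplace

    la = DiagLLLaplace(model, likelihood='classification')
    la.fit(train_loader)
    la.optimize_prior_precision(method='marglik')  # post-hoc prior precision tuning

    x = torch.randn(8, 10)
    probs = la(x, pred_type='glm', link_approx='probit')  # predictive class probabilities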

Ancestors

LLLaplace
DiagLaplace
BaseLaplace

Inherited members