Module laplace.lllaplace
Classes
class LLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Base class for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.
A Laplace approximation is represented by a MAP estimate, given by the model parameter, and a posterior precision or covariance, which together specify a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P, which sums as
P = -\sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} - \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
Every subclass implements a different approximation to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian, so its term has the simple form -\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision, so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
Parameters
model : torch.nn.Module or FeatureExtractor
likelihood : Likelihood or {'classification', 'regression'}
- determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
- observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
- prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
- prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
- temperature of the likelihood; a lower temperature leads to a more concentrated posterior, and vice versa.
enable_backprop : bool, default=False
- whether to enable backprop to the input x through the Laplace predictive. Useful for, e.g., Bayesian optimization.
feature_reduction : FeatureReduction or str, optional, default=None
- when the last-layer features tensor has dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g., in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When passing a string, the available options are {'pick_first', 'pick_last', 'average'}.
dict_key_x : str, default='input_ids'
- the dictionary key under which the input tensor x is stored. Only has an effect when the model takes a MutableMapping as input. Useful for Hugging Face LLM models.
dict_key_y : str, default='labels'
- the dictionary key under which the target tensor y is stored. Only has an effect when the model takes a MutableMapping as input. Useful for Hugging Face LLM models.
backend : subclasses of CurvatureInterface
- backend for access to curvature/Hessian approximations
last_layer_name : str, default=None
- name of the model's last layer; if None, it will be determined automatically
backend_kwargs : dict, default=None
- arguments passed to the backend on initialization, for example, to set the number of MC samples for stochastic approximations.
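The three string options of feature_reduction can be pictured with a small sketch. reduce_features below is a hypothetical helper written for illustration, not the library's implementation; it assumes the sequence dimension is dimension 1, as in the (batch_size, seq_len, embd_dim) example above.

```python
import torch


def reduce_features(features: torch.Tensor, reduction: str) -> torch.Tensor:
    """Reduce (batch_size, seq_len, embd_dim) features to (batch_size, embd_dim).

    Hypothetical helper mirroring the documented string options.
    """
    if features.ndim < 3:
        # Already (batch_size, embd_dim): nothing to reduce.
        return features
    if reduction == "pick_first":
        # Keep only the first position of the sequence dimension.
        return features[:, 0]
    if reduction == "pick_last":
        # Keep only the last position of the sequence dimension.
        return features[:, -1]
    if reduction == "average":
        # Mean over the sequence dimension.
        return features.mean(dim=1)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For instance, a classifier head that reads only the first ([CLS]) token would correspond to 'pick_first'; choosing a different option than the model actually uses would make the last-layer features inconsistent with the network's predictions.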
Ancestors
Subclasses
Instance variables
var prior_precision_diag : torch.Tensor
Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.
Returns
prior_precision_diag : torch.Tensor
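To make the roles of p_0 and P concrete, the sketch below expands a scalar or diagonal prior precision into p_0 and assembles the dense posterior precision P for the simplest case in which the log likelihood Hessian is exact: a bias-free, single-output linear last layer with Gaussian likelihood. Both functions are hypothetical illustrations rather than library code, and dividing the likelihood term by temperature is an assumption consistent with the temperature description above (a lower temperature gives a larger P, i.e. a more concentrated posterior).

```python
import torch


def prior_precision_diag(prior_precision, n_params):
    """Expand a scalar or diagonal prior precision into the vector p_0 (hypothetical helper)."""
    p = torch.as_tensor(prior_precision, dtype=torch.get_default_dtype()).reshape(-1)
    if p.numel() == 1:
        # Scalar prior precision: isotropic prior, same precision for every parameter.
        return p.expand(n_params).clone()
    if p.numel() == n_params:
        # Diagonal prior precision: one value per last-layer parameter.
        return p
    raise ValueError("prior precision must be a scalar or have length n_params")


def posterior_precision(phi, sigma_noise=1.0, prior_precision=1.0, temperature=1.0):
    """Dense P = H / temperature + diag(p_0) for a bias-free, single-output linear
    last layer with Gaussian likelihood; phi holds last-layer features, shape (N, d)."""
    # For y_n ~ N(w^T phi_n, sigma^2): -sum_n grad^2_w log p(y_n | w) = Phi^T Phi / sigma^2.
    H = phi.T @ phi / sigma_noise**2
    # The Gaussian prior contributes -grad^2_w log p(w) = diag(p_0).
    p0 = prior_precision_diag(prior_precision, phi.shape[1]).to(phi.dtype)
    return H / temperature + torch.diag(p0)
```

Because the model is linear in the last-layer weights, H here is the exact negative log likelihood Hessian, which can be checked against torch.autograd.functional.hessian.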
Methods
def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader : torch.utils.data.DataLoader
- each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access N, the size of the data set.
override : bool, default=True
- whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False
- whether to show a progress bar while fitting
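The override semantics can be illustrated with a simplified accumulator. fit_last_layer_hessian is hypothetical: it assumes each batch already holds last-layer features rather than raw inputs, uses the exact Gaussian-likelihood Hessian of a bias-free, single-output linear layer, uses the hypothetical dictionary keys "x" and "y", and counts N from the batch sizes instead of reading train_loader.dataset.

```python
from collections.abc import MutableMapping

import torch


def fit_last_layer_hessian(batches, state=None, *, override=True,
                           dict_key_x="x", dict_key_y="y", sigma_noise=1.0):
    """Accumulate H and n_data over batches (hypothetical sketch of fit())."""
    if override or state is None:
        # Start from scratch, like override=True.
        state = {"H": None, "n_data": 0}
    for batch in batches:
        # Batches are either (X, y) pairs or dict-like objects.
        if isinstance(batch, MutableMapping):
            X, y = batch[dict_key_x], batch[dict_key_y]
        else:
            X, y = batch
        # y would only enter the loss, which this sketch omits; the Gaussian
        # log likelihood Hessian w.r.t. the last-layer weights does not depend on y.
        H_batch = X.T @ X / sigma_noise**2
        state["H"] = H_batch if state["H"] is None else state["H"] + H_batch
        state["n_data"] += X.shape[0]
    return state
```

Because the per-batch terms add up, fitting on a second batch with override=False yields the same H as fitting on both batches at once, which is what makes sequential (online) accumulation possible.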
def functional_variance_fast(self, X)
Should be overridden if there exists a trick to make this fast!
Parameters
X : torch.Tensor of shape (batch_size, input_dim)
Returns
f_var_diag : torch.Tensor of shape (batch_size, num_outputs)
- corresponds to the diagonal of the covariance matrix of the outputs
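As one example of such a trick, the sketch below computes only the per-output variances for a dense last-layer covariance instead of forming the full output covariance J \Sigma J^T. It assumes a bias-free linear last layer f(x) = W \phi(x) with W of shape (C, d), a dense posterior covariance Sigma over the row-major flattening of W, and precomputed features phi; last_layer_functional_variance is a hypothetical name, not the library's method.

```python
import torch


def last_layer_functional_variance(phi, Sigma, n_outputs):
    """Per-output predictive variances, shape (N, C), for f(x) = W phi(x).

    phi: (N, d) last-layer features; Sigma: (C*d, C*d) covariance of row-major vec(W).
    """
    d = phi.shape[1]
    C = n_outputs
    # Output c depends only on row c of W, so Var[f_c] needs only the (c, c) block
    # of Sigma; torch.diagonal moves the block index c to the last dimension.
    blocks = torch.diagonal(Sigma.reshape(C, d, C, d), dim1=0, dim2=2)  # (d, d, C)
    # Var[f_c(x_n)] = phi_n^T Sigma_cc phi_n, without building J Sigma J^T.
    return torch.einsum("ni,ijc,nj->nc", phi, blocks, phi)
```

The cost scales with C d^2 per input instead of requiring the full C \times C output covariance, and the cross-output covariances are simply never computed.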
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
Inherited members
class FullLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{D \times D}, where D is the number of last-layer parameters. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
Inherited members
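For intuition about the generalized Gauss-Newton (GGN) option mentioned for FullLLLaplace, the sketch below builds the dense last-layer GGN for softmax cross-entropy. It assumes a bias-free linear last layer with weights W of shape (C, d) flattened row-major and precomputed features phi; last_layer_ggn_classification is a hypothetical helper, not a backend from the library. Because the logits are linear in W, this GGN coincides with the exact Hessian of the summed cross-entropy.

```python
import torch


def last_layer_ggn_classification(phi, W):
    """Dense GGN of the summed softmax cross-entropy w.r.t. row-major vec(W).

    phi: (N, d) last-layer features; W: (C, d) bias-free last-layer weights.
    """
    probs = torch.softmax(phi @ W.T, dim=-1)  # (N, C)
    # Hessian of the cross-entropy w.r.t. the logits: diag(p_n) - p_n p_n^T.
    lam = torch.diag_embed(probs) - probs[:, :, None] * probs[:, None, :]  # (N, C, C)
    # The logit Jacobian w.r.t. vec(W) is I_C kron phi_n^T, so each sample
    # contributes lam_n kron (phi_n phi_n^T).
    outer = phi[:, :, None] * phi[:, None, :]  # (N, d, d)
    ggn = torch.einsum("nac,nij->aicj", lam, outer)  # (C, d, C, d)
    C, d = W.shape
    return ggn.reshape(C * d, C * d)
```

Adding diag(p_0) to this matrix gives the dense posterior precision P described in the class docstring.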
class KronLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a Kronecker-factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing them up, and KronDecomposed is used to add the prior and a Hessian factor (e.g., temperature) and to compute posterior covariances, the marginal likelihood, etc. Damping can be used by initializing with or setting damping=True.
Ancestors
Inherited members
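The reason a decomposed Kronecker form makes posterior covariances and log-determinants cheap can be shown with plain eigendecompositions: if A = Q_A \Lambda_A Q_A^T and B = Q_B \Lambda_B Q_B^T, then c (A \otimes B) + \delta I has eigenvectors Q_A \otimes Q_B and eigenvalues c \lambda_{A,i} \lambda_{B,j} + \delta. The function below is a hypothetical sketch of that idea using only torch; it does not use the library's Kron or KronDecomposed classes, does not model the damping option, and assembles the covariance densely only so the result is easy to check.

```python
import torch


def kron_posterior_logdet_and_cov(A, B, prior_precision=1.0, h_factor=1.0):
    """Log-determinant and covariance of P = h_factor * (A kron B) + prior_precision * I.

    A, B: symmetric positive semi-definite Kronecker factors (hypothetical sketch).
    """
    # Eigendecompose each small factor instead of the large Kronecker product.
    lam_a, Q_a = torch.linalg.eigh(A)
    lam_b, Q_b = torch.linalg.eigh(B)
    # Eigenvalues of P, ordered consistently with the columns of kron(Q_a, Q_b).
    eig = h_factor * torch.kron(lam_a, lam_b) + prior_precision
    logdet_precision = torch.log(eig).sum()
    # Dense assembly for illustration only; a practical implementation would
    # apply kron(Q_a, Q_b) through reshapes rather than forming it.
    Q = torch.kron(Q_a, Q_b)
    cov = Q @ torch.diag(1.0 / eig) @ Q.T
    return logdet_precision, cov
```

Only the two small factors are ever decomposed, which is what keeps marginal-likelihood and covariance computations tractable for a large last layer.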
class DiagLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
Inherited members
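A diagonal approximation keeps only the diagonal of the Hessian approximation, so the posterior variances become elementwise reciprocals. The sketch below shows this for softmax cross-entropy, assuming a bias-free linear last layer with weights W of shape (C, d), row-major flattening, and precomputed features phi; both function names are hypothetical and the diagonal GGN is computed directly rather than through a library backend.

```python
import torch


def last_layer_diag_ggn_classification(phi, W):
    """Diagonal of the softmax cross-entropy GGN w.r.t. row-major vec(W)."""
    probs = torch.softmax(phi @ W.T, dim=-1)  # (N, C)
    # Entry (c, i): sum_n p_nc (1 - p_nc) phi_ni^2, i.e. the diagonal of
    # sum_n (diag(p_n) - p_n p_n^T) kron (phi_n phi_n^T).
    diag = torch.einsum("nc,ni->ci", probs * (1 - probs), phi**2)
    return diag.reshape(-1)


def diag_posterior_variance(phi, W, prior_precision=1.0):
    """Posterior variances 1 / (diag(GGN) + p_0); prior_precision is a scalar or (C*d,) tensor."""
    return 1.0 / (last_layer_diag_ggn_classification(phi, W) + prior_precision)
```

This needs only O(C d) memory, compared with O((C d)^2) for the dense FullLLLaplace precision, at the cost of ignoring all correlations between weights.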