Module laplace.baselaplace
Classes
class BaseLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Base class for all Laplace approximations in this library.
Parameters
model
:torch.nn.Module
likelihood
:Likelihood
or str in {'classification', 'regression', 'reward_modeling'}
- determines the log likelihood Hessian approximation.
In the case of 'reward_modeling', it fits Laplace using the classification likelihood,
then does prediction as in regression likelihood. The model needs to be defined accordingly:
the forward pass during training takes
x.shape == (batch_size, 2, dim)
with y.shape == (batch_size,).
Meanwhile, during evaluation, x.shape == (batch_size, dim).
Note that 'reward_modeling' only supports KronLaplace
and DiagLaplace.
sigma_noise
:torch.Tensor
or float
, default=1
- observation noise for the regression setting; must be 1 for classification
prior_precision
:torch.Tensor
or float
, default=1
- prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean
:torch.Tensor
or float
, default=0
- prior mean of a Gaussian prior, useful for continual learning
temperature
:float
, default=1
- temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.
enable_backprop
:bool
, default=False
- whether to enable backprop to the input
x
through the Laplace predictive. Useful for e.g. Bayesian optimization.
dict_key_x
:str
, default='input_ids'
- The dictionary key under which the input tensor
x
is stored. Only has an effect when the model takes a MutableMapping
as the input. Useful for Hugging Face LLMs.
dict_key_y
:str
, default='labels'
- The dictionary key under which the target tensor
y
is stored. Only has an effect when the model takes a MutableMapping
as the input. Useful for Hugging Face LLMs.
backend
:subclasses
of CurvatureInterface
- backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
backend_kwargs
:dict
, default=None
- arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
asdl_fisher_kwargs
:dict
, default=None
- arguments passed to the ASDL backend specifically on initialization.
Subclasses
- laplace.baselaplace.FunctionalLaplace
- ParametricLaplace
Instance variables
var backend : CurvatureInterface
var log_likelihood : torch.Tensor
-
Compute log likelihood on the training data after
.fit()
has been called. The log likelihood is computed on demand based on the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.
Returns
log_likelihood
:torch.Tensor
var prior_precision_diag : torch.Tensor
-
Obtain the diagonal prior precision p_0 constructed from either a scalar, layer-wise, or diagonal prior precision.
Returns
prior_precision_diag
:torch.Tensor
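The expansion from a scalar, layer-wise, or diagonal prior precision to the diagonal p_0 can be sketched in plain Python as follows. This is only an illustration: the helper name `expand_prior_precision` and the argument `params_per_layer` are hypothetical and not part of the library, which operates on torch tensors.

```python
def expand_prior_precision(prior_precision, params_per_layer):
    """Expand a prior precision to one entry per parameter (hypothetical helper).

    prior_precision: list of floats of length 1 (scalar),
        len(params_per_layer) (layer-wise), or sum(params_per_layer) (diagonal).
    params_per_layer: number of parameters in each layer.
    """
    n_layers = len(params_per_layer)
    n_params = sum(params_per_layer)
    if len(prior_precision) == 1:
        # Scalar: the same precision for every parameter.
        return [prior_precision[0]] * n_params
    if len(prior_precision) == n_params:
        # Diagonal: already one precision per parameter.
        return list(prior_precision)
    if len(prior_precision) == n_layers:
        # Layer-wise: repeat each layer's precision for all of its parameters.
        return [p for p, n in zip(prior_precision, params_per_layer) for _ in range(n)]
    raise ValueError("prior precision must be scalar, layer-wise, or diagonal")


# Two layers with 2 and 3 parameters, layer-wise prior precision.
p0 = expand_prior_precision([1.0, 4.0], [2, 3])
```

Here `p0` has one entry per parameter, so the three structures can be handled uniformly downstream.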
var prior_mean : torch.Tensor
var prior_precision : torch.Tensor
var sigma_noise : torch.Tensor
Methods
def fit(self, train_loader: DataLoader) ‑> None
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None) ‑> torch.Tensor
def predictive(self, x: torch.Tensor, pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int) ‑> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
def optimize_prior_precision(self, pred_type: PredType | str, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.DIAG, val_loader: DataLoader | None = None, loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) ‑> None
-
Optimize the prior precision post-hoc using the
method
specified by the user.
Parameters
pred_type
:PredType
or str in {'glm', 'nn'}
- type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
method
:TuningMethod
or str in {'marglik', 'gridsearch'}
, default=TuningMethod.MARGLIK
- specifies how the prior precision should be optimized.
n_steps
:int
, default=100
- the number of gradient descent steps to take.
lr
:float
, default=1e-1
- the learning rate to use for gradient descent.
init_prior_prec
:float
or tensor
, default=1.0
- initial prior precision before the first optimization step.
prior_structure
:PriorStructure
or str in {'scalar', 'layerwise', 'diag'}
, default=PriorStructure.DIAG
- if init_prior_prec is scalar, the prior precision is optimized with this structure. Otherwise, the structure of init_prior_prec is maintained.
val_loader
:torch.utils.data.DataLoader
, default=None
- DataLoader for the validation set; each iterate is a batch (X, y).
loss
:callable
or torchmetrics.Metric
, default=None
- loss function to use for CV. If callable, the loss is computed offline (memory intensive).
If torchmetrics.Metric, running loss is computed (efficient). The default
depends on the likelihood:
RunningNLLMetric()
for classification and reward modeling, running
MeanSquaredError()
for regression.
log_prior_prec_min
:float
, default=-4
- lower bound of gridsearch interval.
log_prior_prec_max
:float
, default=4
- upper bound of gridsearch interval.
grid_size
:int
, default=100
- number of values to consider inside the gridsearch interval.
link_approx
:LinkApprox
or str in {'mc', 'probit', 'bridge'}
, default=LinkApprox.PROBIT
- how to approximate the classification link function for the
'glm'
predictive. For
pred_type='nn'
, only
'mc'
is possible.
n_samples
:int
, default=100
- number of samples for
link_approx='mc'
.
verbose
:bool
, default=False
- if true, the optimized prior precision will be printed (can be a large tensor if the prior has a diagonal covariance).
progress_bar
:bool
, default=False
- whether to show a progress bar; updated at every batch-Hessian computation.
Useful for very large models and large amounts of data, esp. when
subset_of_weights='all'
.
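To make the roles of log_prior_prec_min, log_prior_prec_max, and grid_size concrete, here is a minimal plain-Python sketch of a grid search over log-spaced candidates. It assumes base-10 log spacing and uses a made-up `toy_loss` in place of a validation loss; the function names are hypothetical and this is not the library's implementation.

```python
import math


def log_spaced_grid(log_min=-4.0, log_max=4.0, grid_size=100):
    """Candidate prior precisions, evenly spaced in log10 (assumed base)."""
    if grid_size == 1:
        return [10.0 ** log_min]
    step = (log_max - log_min) / (grid_size - 1)
    return [10.0 ** (log_min + i * step) for i in range(grid_size)]


def grid_search(loss_fn, grid):
    """Return the candidate with the smallest loss."""
    return min(grid, key=loss_fn)


def toy_loss(prior_prec):
    # Hypothetical stand-in for a validation loss, minimized at prior_prec = 10.
    return (math.log10(prior_prec) - 1.0) ** 2


grid = log_spaced_grid(-4, 4, grid_size=9)  # 1e-4, 1e-3, ..., 1e4
best = grid_search(toy_loss, grid)
```

A finer grid (larger grid_size) trades more loss evaluations for a better-resolved optimum within the same interval.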
class ParametricLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Parametric Laplace class.
Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.
A Laplace approximation is represented by a MAP which is given by the
model
parameter and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to compute the posterior precision P which sums as P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
Ancestors
Subclasses
Instance variables
var scatter : torch.Tensor
-
Computes the scatter, a term of the log marginal likelihood that corresponds to L2 regularization:
scatter
= (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0).
Returns
scatter
:torch.Tensor
var log_det_prior_precision : torch.Tensor
-
Compute log determinant of the prior precision \log \det P_0
Returns
log_det
:torch.Tensor
var log_det_posterior_precision : torch.Tensor
-
Compute log determinant of the posterior precision \log \det P, which depends on the structure the subclass uses for the Hessian approximation.
Returns
log_det
:torch.Tensor
var log_det_ratio : torch.Tensor
-
Compute the log determinant ratio, a part of the log marginal likelihood. \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
Returns
log_det_ratio
:torch.Tensor
var posterior_precision : torch.Tensor
-
Compute or return the posterior precision P.
Returns
posterior_prec
:torch.Tensor
Methods
def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None
-
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader
:torch.utils.data.DataLoader
- each iterate is a training batch, either
(X, y)
tensors or a dict-like object containing keys as expressed by self.dict_key_x
and self.dict_key_y.
train_loader.dataset
needs to be set to access N, the size of the data set.
override
:bool
, default=True
- whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar
:bool
, default=False
- whether to show a progress bar; updated at every batch-Hessian computation.
Useful for very large models and large amounts of data, esp. when
subset_of_weights='all'
.
def square_norm(self, value) ‑> torch.Tensor
-
Compute the square norm under the posterior precision with
value - self.mean
as \Delta: \Delta^\top P \Delta
Returns
square_form
def log_prob(self, value: torch.Tensor, normalized: bool = True) ‑> torch.Tensor
-
Compute the log probability under the (current) Laplace approximation.
Parameters
value
:torch.Tensor
normalized
:bool
, default=True
- whether to return log of a properly normalized Gaussian or just the
terms that depend on
value
.
Returns
log_prob
:torch.Tensor
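The effect of `normalized` can be illustrated for a diagonal posterior precision, where the Gaussian log density is -\frac{D}{2} \log 2\pi + \frac{1}{2} \log \det P - \frac{1}{2} \Delta^\top P \Delta. The plain-Python sketch below uses a hypothetical helper, `gaussian_log_prob_diag`, and is not the library's code.

```python
import math


def gaussian_log_prob_diag(value, mean, precision_diag, normalized=True):
    """Log density of N(mean, diag(precision_diag)^-1) at value (hypothetical helper)."""
    delta = [v - m for v, m in zip(value, mean)]
    # Square norm under the precision: Delta^T P Delta for diagonal P.
    square_norm = sum(p * d * d for p, d in zip(precision_diag, delta))
    log_prob = -0.5 * square_norm
    if normalized:
        # Add the terms that do not depend on value.
        n_params = len(value)
        log_det = sum(math.log(p) for p in precision_diag)
        log_prob += -0.5 * n_params * math.log(2 * math.pi) + 0.5 * log_det
    return log_prob


unnormalized = gaussian_log_prob_diag([1.0, 0.0], [0.0, 0.0], [4.0, 1.0], normalized=False)
at_mean = gaussian_log_prob_diag([0.0, 0.0], [0.0, 0.0], [1.0, 1.0])
```

With `normalized=False` only the quadratic term remains, which is enough when comparing values under a fixed posterior.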
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None) ‑> torch.Tensor
-
Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in
prior_precision
and sigma_noise
if these have gradients enabled. By passing prior_precision
or sigma_noise
, the current value is overwritten. This is useful for iterating on the log marginal likelihood.
Parameters
prior_precision
:torch.Tensor
, optional
- prior precision, if it should be changed from the current
prior_precision
value
sigma_noise
:torch.Tensor
, optional
- observation noise standard deviation, if it should be changed
Returns
log_marglik
:torch.Tensor
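As a sketch of how the scatter and the log determinant ratio enter this quantity, the plain-Python example below assumes the standard Laplace evidence, \log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta_{MAP}) - \frac{1}{2} (\log \frac{\det P}{\det P_0} + \textrm{scatter}), together with a diagonal Hessian approximation. The function `log_marglik_diag` and its arguments are hypothetical names for illustration, not the library's API.

```python
import math


def log_marglik_diag(log_likelihood, hessian_diag, prior_precision_diag,
                     theta_map, prior_mean):
    """Laplace log evidence for a diagonal posterior (hypothetical helper)."""
    # Posterior precision P = H + P_0, elementwise in the diagonal case.
    posterior_precision = [h + p0 for h, p0 in zip(hessian_diag, prior_precision_diag)]
    # log det P - log det P_0
    log_det_ratio = (sum(math.log(p) for p in posterior_precision)
                     - sum(math.log(p0) for p0 in prior_precision_diag))
    # (theta_MAP - mu_0)^T P_0 (theta_MAP - mu_0)
    scatter = sum(p0 * (t - m) ** 2
                  for p0, t, m in zip(prior_precision_diag, theta_map, prior_mean))
    return log_likelihood - 0.5 * (log_det_ratio + scatter)


marglik = log_marglik_diag(
    log_likelihood=-10.0,
    hessian_diag=[1.0, 3.0],
    prior_precision_diag=[1.0, 1.0],
    theta_map=[1.0, 2.0],
    prior_mean=[0.0, 0.0],
)
```

In this toy case the log determinant ratio is \log 8 and the scatter is 5, so both terms penalize the fit relative to the raw log likelihood.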
def predictive_samples(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None) ‑> torch.Tensor
-
Sample from the posterior predictive on input data
x
. Can be used, for example, for Thompson sampling.
Parameters
x
:torch.Tensor
or MutableMapping
- input data
(batch_size, input_shape)
pred_type
:{'glm', 'nn'}
, default='glm'
- type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
n_samples
:int
- number of samples
diagonal_output
:bool
- whether to use a diagonalized glm posterior predictive on the outputs.
Only applies when
pred_type='glm'
.
generator
:torch.Generator
, optional
- random number generator to control the samples (if sampling is used)
Returns
samples
:torch.Tensor
- samples
(n_samples, batch_size, output_shape)
def functional_variance(self, Js: torch.Tensor) ‑> torch.Tensor
-
Compute functional variance for the
'glm'
predictive:
f_var[i] = Js[i] @ P.inv() @ Js[i].T
, which is an output x output predictive covariance matrix. Mathematically, we have for a single Jacobian \mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} the output covariance matrix \mathcal{J} P^{-1} \mathcal{J}^T.
Parameters
Js
:torch.Tensor
- Jacobians of model output wrt parameters
(batch, outputs, parameters)
Returns
f_var
:torch.Tensor
- output covariance
(batch, outputs, outputs)
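For a diagonal posterior precision, the product \mathcal{J} P^{-1} \mathcal{J}^T reduces to a weighted inner product of Jacobian rows. The following plain-Python sketch uses nested lists instead of tensors and a hypothetical helper, `functional_variance_diag`; it illustrates the shapes involved rather than the library's implementation.

```python
def functional_variance_diag(Js, posterior_precision_diag):
    """Compute Js[i] @ P^-1 @ Js[i].T for diagonal P (hypothetical helper).

    Js: nested list with shape (batch, outputs, parameters).
    Returns a nested list with shape (batch, outputs, outputs).
    """
    inv_precision = [1.0 / p for p in posterior_precision_diag]
    f_var = []
    for J in Js:
        # Entry (i, j) is sum_k J[i][k] * P^-1[k] * J[j][k].
        cov = [
            [sum(a * s * b for a, s, b in zip(row_i, inv_precision, row_j)) for row_j in J]
            for row_i in J
        ]
        f_var.append(cov)
    return f_var


# One input, two outputs, two parameters.
Js = [[[1.0, 0.0], [1.0, 2.0]]]
f_var = functional_variance_diag(Js, [2.0, 4.0])
```

The result is symmetric per batch element, as expected of a covariance matrix.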
def functional_covariance(self, Js: torch.Tensor) ‑> torch.Tensor
-
Compute functional covariance for the
'glm'
predictive:
f_cov = Js @ P.inv() @ Js.T
, which is a batch*output x batch*output predictive covariance matrix.
This emulates the GP posterior covariance N([f(x1), …, f(xm)], Cov[f(x1), …, f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization.
Parameters
Js
:torch.Tensor
- Jacobians of model output wrt parameters
(batch*outputs, parameters)
Returns
f_cov
:torch.Tensor
- output covariance
(batch*outputs, batch*outputs)
def sample(self, n_samples: int = 100, generator: torch.Generator | None = None) ‑> torch.Tensor
-
Sample from the Laplace posterior approximation, i.e., \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
Parameters
n_samples
:int
, default=100
- number of samples
generator
:torch.Generator
, optional
- random number generator to control the samples
Returns
samples
:torch.Tensor
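Sampling \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}) amounts to \theta = \theta_{MAP} + P^{-1/2} \epsilon with \epsilon \sim \mathcal{N}(0, I). The plain-Python sketch below assumes a diagonal P and uses a seeded `random.Random` in place of a torch.Generator; the helper `sample_diag_laplace` is hypothetical.

```python
import random


def sample_diag_laplace(theta_map, posterior_precision_diag, n_samples=100, seed=None):
    """Draw samples from N(theta_map, diag(precision)^-1) (hypothetical helper).

    Returns a nested list with shape (n_samples, parameters).
    """
    rng = random.Random(seed)  # seeded generator for reproducible samples
    # Posterior scale is the elementwise inverse square root of the precision.
    scale = [p ** -0.5 for p in posterior_precision_diag]
    return [
        [m + s * rng.gauss(0.0, 1.0) for m, s in zip(theta_map, scale)]
        for _ in range(n_samples)
    ]


samples = sample_diag_laplace([1.0, -2.0], [4.0, 100.0], n_samples=5, seed=0)
```

Passing the same seed reproduces the same samples, which mirrors the purpose of the generator argument.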
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
Inherited members
class FullLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None)
-
Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen
backend
parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See
BaseLaplace
for the full interface.
Ancestors
Subclasses
Instance variables
var posterior_scale : torch.Tensor
-
Posterior scale (square root of the covariance), i.e., P^{-\frac{1}{2}}.
Returns
scale
:torch.tensor
(parameters, parameters)
var posterior_covariance : torch.Tensor
-
Posterior covariance, i.e., P^{-1}.
Returns
covariance
:torch.tensor
(parameters, parameters)
var posterior_precision : torch.Tensor
-
Posterior precision P.
Returns
precision
:torch.tensor
(parameters, parameters)
Inherited members
class KronLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., torch.nn.Module, that P \approx Q \otimes H. See
BaseLaplace
for the full interface and see
Kron
and
KronDecomposed
for the structure of the Kronecker factors.
Kron
is used to aggregate factors by summing up and
KronDecomposed
is used to add the prior, a Hessian factor (e.g. temperature), and to compute posterior covariances, marginal likelihood, etc. Damping can be enabled by setting
damping=True
.
Ancestors
Subclasses
Instance variables
var posterior_precision : KronDecomposed
var prior_precision : torch.Tensor
Methods
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any])
Inherited members
class DiagLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See
BaseLaplace
for the full interface.
Ancestors
Subclasses
Instance variables
var posterior_precision : torch.Tensor
-
Diagonal posterior precision p.
Returns
precision
:torch.tensor
(parameters)
var posterior_scale : torch.Tensor
-
Diagonal posterior scale \sqrt{p^{-1}}.
Returns
precision
:torch.tensor
(parameters)
var posterior_variance : torch.Tensor
-
Diagonal posterior variance p^{-1}.
Returns
precision
:torch.tensor
(parameters)
Inherited members
class LowRankLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1, prior_precision: float | torch.Tensor = 1, prior_mean: float | torch.Tensor = 0, temperature: float = 1, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend=laplace.curvature.asdfghjkl.AsdfghjklHessian, backend_kwargs: dict[str, Any] | None = None)
-
Laplace approximation with low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Based on the chosen
backend
, either a true Hessian or, for example, a GGN approximation could be used. The posterior precision is computed as P = V diag(l) V^T + P_0. To sample, compute the functional variance, and compute the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a K \times K matrix if we have a rank of K. See
BaseLaplace
for the full interface.
Ancestors
Instance variables
var V : torch.Tensor
var Kinv : torch.Tensor
var posterior_precision : tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]
-
Return correctly scaled posterior precision that would be constructed as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
Returns
H
:tuple(eigenvectors, eigenvalues)
- scaled self.H with temperature and loss factors.
prior_precision_diag
:torch.Tensor
- diagonal prior precision shape
parameters
to be added to H.
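One example of such an algebraic trick is the matrix determinant lemma. For the rank-1 case (K = 1) with diagonal P_0, it gives \det(P_0 + l v v^T) = \det(P_0) (1 + l v^T P_0^{-1} v), so the K \times K matrix is a single scalar. The plain-Python sketch below illustrates only this identity; the helper name is hypothetical and the library's own implementation may differ.

```python
import math


def log_det_rank1_plus_diag(v, l, p0):
    """log det(diag(p0) + l * v v^T) via the matrix determinant lemma.

    v: vector (list of floats), l: scalar eigenvalue, p0: positive diagonal of P_0.
    Hypothetical helper; costs O(parameters) instead of a full determinant.
    """
    # v^T P_0^-1 v for diagonal P_0
    quad = sum(vi * vi / pi for vi, pi in zip(v, p0))
    log_det_prior = sum(math.log(pi) for pi in p0)
    return log_det_prior + math.log1p(l * quad)


# Direct check on a 2 x 2 example: diag(2, 5) + 3 * [1, 2][1, 2]^T = [[5, 6], [6, 17]].
log_det = log_det_rank1_plus_diag([1.0, 2.0], 3.0, [2.0, 5.0])
direct = math.log(5.0 * 17.0 - 6.0 * 6.0)
```

Both routes give \log 49 here; for rank K > 1 the same idea leads to the determinant of a K \times K matrix.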
Inherited members