Module laplace.baselaplace

Classes

class BaseLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Base class for all Laplace approximations in this library.

Parameters

model : torch.nn.Module
 
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa.
backend : subclass of CurvatureInterface
backend for access to curvature/Hessian approximations
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

Subclasses

Instance variables

var backend
var log_likelihood

Compute the log likelihood on the training data after .fit() has been called. The log likelihood is computed on demand from the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.

Returns

log_likelihood : torch.Tensor
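For regression, the on-demand computation can be pictured as a Gaussian log likelihood assembled from the summed squared error of the MAP model and sigma_noise. The sketch below is an illustration under assumptions, not the library's code: the helper gaussian_log_likelihood is hypothetical and ignores temperature. Because sigma_noise is a leaf tensor, autograd can differentiate the result with respect to it.

```python
import math
import torch

def gaussian_log_likelihood(sse, n_data, n_outputs, sigma_noise):
    # Hypothetical helper: sum over N*K outputs of log N(y | f(x), sigma^2),
    # expressed through the summed squared error (sse) of the MAP model.
    normalizer = n_data * n_outputs * torch.log(sigma_noise * math.sqrt(2 * math.pi))
    return -0.5 * sse / sigma_noise ** 2 - normalizer

# Residuals y - f(x) of a fitted model on N=4 points with K=1 output (toy values).
residuals = torch.tensor([0.5, -1.0, 0.25, 2.0])
sse = (residuals ** 2).sum()

# sigma_noise with gradients enabled, so the log likelihood is differentiable in it.
sigma_noise = torch.tensor(1.5, requires_grad=True)
log_lik = gaussian_log_likelihood(sse, n_data=4, n_outputs=1, sigma_noise=sigma_noise)
log_lik.backward()
print(log_lik.item(), sigma_noise.grad.item())
```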
 
var prior_precision_diag

Obtain the diagonal prior precision p_0 constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
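A minimal sketch of how the three prior-precision shapes could be expanded into the vector p_0, assuming parameters are ordered group by group and each group's size is known. The helper expand_prior_precision and the layer_sizes argument are hypothetical, not part of this module.

```python
import torch

def expand_prior_precision(prior_precision, layer_sizes):
    # Hypothetical helper: turn a scalar, per-layer, or per-parameter
    # prior precision into the vector p_0 with P_0 = diag(p_0).
    n_params = sum(layer_sizes)
    prec = torch.as_tensor(prior_precision, dtype=torch.float32).flatten()
    if prec.numel() == 1:
        # scalar: the same precision for every parameter
        return prec.repeat(n_params)
    if prec.numel() == len(layer_sizes):
        # layer-wise: repeat each layer's precision once per parameter in that layer
        return torch.repeat_interleave(prec, torch.tensor(layer_sizes))
    if prec.numel() == n_params:
        # diagonal: already one precision per parameter
        return prec
    raise ValueError("prior precision must be scalar, layer-wise, or diagonal")

# A toy model with two parameter groups of sizes 3 and 2.
sizes = [3, 2]
print(expand_prior_precision(2.0, sizes))          # scalar
print(expand_prior_precision([1.0, 5.0], sizes))   # layer-wise
```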
 
var prior_mean
var prior_precision
var sigma_noise

Methods

def fit(self, train_loader)
def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)
def predictive(self, x, pred_type, link_approx, n_samples)
def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False)

Optimize the prior precision post-hoc using the method specified by the user.

Parameters

pred_type : {'glm', 'nn', 'gp'}, default='glm'
type of posterior predictive: linearized GLM predictive, neural network sampling predictive, or Gaussian process (GP) inference. The GLM predictive is consistent with the curvature approximations used here.
method : {'marglik', 'CV'}, default='marglik'
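specifies how the prior precision should be optimized: 'marglik' maximizes the log marginal likelihood with gradient-based steps (see n_steps and lr), while 'CV' runs a grid search that minimizes loss on val_loader.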
n_steps : int, default=100
the number of gradient descent steps to take.
lr : float, default=1e-1
the learning rate to use for gradient descent.
init_prior_prec : float, default=1.0
initial prior precision before the first optimization step.
val_loader : torch.utils.data.DataLoader, default=None
DataLoader for the validation set; each iterate is a validation batch (X, y).
loss : callable, default=get_nll
loss function to use for CV.
cv_loss_with_var : bool, default=False
if True, loss takes three arguments, loss(output_mean, output_var, target); otherwise, it takes two, loss(output_mean, target).
log_prior_prec_min : float, default=-4
lower bound of the grid-search interval for CV.
log_prior_prec_max : float, default=4
upper bound of the grid-search interval for CV.
grid_size : int, default=100
number of values to consider inside the grid-search interval for CV.
link_approx : {'mc', 'probit', 'bridge'}, default='probit'
how to approximate the classification link function for pred_type='glm'. For pred_type='nn', only 'mc' is possible.
n_samples : int, default=100
number of samples for link_approx='mc'.
verbose : bool, default=False
if True, the optimized prior precision is printed (this can be a large tensor if the prior has a diagonal covariance).
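For method='CV', the search can be sketched as evaluating a validation loss at log-spaced candidate precisions and keeping the best one. This is a sketch under stated assumptions: the grid is assumed to be base-10 log-spaced (so the defaults span 1e-4 to 1e4), the model is a hypothetical Bayesian linear regression with GLM predictive variance x^T P^{-1} x + \sigma^2, and gaussian_nll is a hypothetical stand-in for get_nll using the three-argument form that cv_loss_with_var=True describes.

```python
import math
import torch

def gaussian_nll(mean, var, target):
    # Hypothetical three-argument loss: loss(output_mean, output_var, target)
    return 0.5 * (torch.log(2 * math.pi * var) + (target - mean) ** 2 / var).sum()

def grid_search_prior_precision(val_loss, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100):
    # Candidate precisions, log-spaced (base 10) between the two bounds
    candidates = torch.logspace(log_prior_prec_min, log_prior_prec_max, grid_size)
    losses = torch.stack([val_loss(p) for p in candidates])
    return candidates[torch.argmin(losses)]

# Toy data: linear regression with a fixed MAP estimate and unit noise.
torch.manual_seed(0)
X_train, X_val = torch.randn(40, 3), torch.randn(20, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])
y_val = X_val @ w_true + torch.randn(20)
H = X_train.T @ X_train                     # log likelihood Hessian for sigma_noise = 1
theta_map = w_true + 0.1 * torch.randn(3)   # stand-in for a trained model

def val_loss(prior_prec):
    # GLM predictive on the validation set for a given scalar prior precision
    cov = torch.linalg.inv(H + prior_prec * torch.eye(3))
    mean = X_val @ theta_map
    var = ((X_val @ cov) * X_val).sum(dim=-1) + 1.0   # diag(X cov X^T) + sigma^2
    return gaussian_nll(mean, var, y_val)

best = grid_search_prior_precision(val_loss)
print(best.item())
```

Note that only the predictive variance depends on the prior precision here, because theta_map is held fixed, which is the post-hoc setting this method targets.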
class ParametricLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Parametric Laplace class.

Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to compute the posterior precision P, which is the sum P = \sum_{n=1}^N -\nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} - \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example, a diagonal one.

The prior is assumed to be Gaussian, so its term has the simple form -\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and only the structure of p_0 varies.
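For intuition, P can be computed exactly for a tiny model as the Hessian of the negative log joint -\log p(\mathcal{D} \mid \theta) - \log p(\theta) at \theta_{MAP}. The sketch below assumes a hypothetical logistic-regression model with a zero-mean diagonal prior, obtains P with torch.autograd.functional.hessian, and checks it against the closed form X^T \textrm{diag}(s(1-s)) X + P_0, where s is the sigmoid output. The subclasses replace this exact Hessian with cheaper approximations.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(30, 2)
y = (X[:, 0] - X[:, 1] > 0).float()
p0 = torch.tensor([1.0, 2.0])             # diagonal prior precision, P_0 = diag(p0)
theta_map = torch.tensor([0.8, -0.7])     # stand-in for a trained MAP estimate

def neg_log_joint(theta):
    # -log p(D | theta) - log p(theta), dropping theta-independent constants
    logits = X @ theta
    nll = F.binary_cross_entropy_with_logits(logits, y, reduction="sum")
    return nll + 0.5 * (p0 * theta ** 2).sum()

# P = sum_n -Hessian log p(D_n | theta) + P_0, evaluated at theta_MAP
P = torch.autograd.functional.hessian(neg_log_joint, theta_map)

# Closed form for logistic regression: sum_n s_n (1 - s_n) x_n x_n^T + P_0
s = torch.sigmoid(X @ theta_map)
P_closed = X.T @ (X * (s * (1 - s)).unsqueeze(-1)) + torch.diag(p0)
print(P)
```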

Ancestors

Subclasses

Instance variables

var scatter

Computes the scatter, a term of the log marginal likelihood that corresponds to L2 regularization: scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0).

Returns

scatter : torch.Tensor

var log_det_prior_precision

Compute the log determinant of the prior precision, \log \det P_0.

Returns

log_det : torch.Tensor
 
var log_det_posterior_precision

Compute the log determinant of the posterior precision, \log \det P, which depends on the structure of the Hessian approximation used by the subclass.

Returns

log_det : torch.Tensor
 
var log_det_ratio

Compute the log determinant ratio, a part of the log marginal likelihood: \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0.

Returns

log_det_ratio : torch.Tensor
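The scatter and the log-determinant terms can be sketched directly with dense tensors. The snippet assumes a diagonal prior P_0 = \textrm{diag}(p_0) and a hypothetical dense posterior precision P, as a FullLaplace would hold; \log \det P uses torch.linalg.slogdet, while the diagonal prior term reduces to \sum_i \log p_{0,i}.

```python
import torch

torch.manual_seed(0)
d = 4
theta_map = torch.randn(d)
prior_mean = torch.zeros(d)
p0 = torch.tensor([1.0, 0.5, 3.0, 1.5])   # P_0 = diag(p0)
A = torch.randn(d, d)
P = A @ A.T + torch.diag(p0)               # toy dense, positive-definite posterior precision

# scatter = (theta_MAP - mu_0)^T P_0 (theta_MAP - mu_0); diagonal P_0 needs no matrix product
delta = theta_map - prior_mean
scatter = (delta * p0 * delta).sum()

# log det P_0 of a diagonal matrix is the sum of the log diagonal entries
log_det_prior = p0.log().sum()

# log det P of a dense symmetric positive-definite matrix
log_det_posterior = torch.linalg.slogdet(P).logabsdet

log_det_ratio = log_det_posterior - log_det_prior
print(scatter.item(), log_det_ratio.item())
```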
 
var posterior_precision

Compute or return the posterior precision P.

Returns

posterior_prec : torch.Tensor
 

Methods

def fit(self, train_loader, override=True)

Fit the local Laplace approximation at the parameters of the model.

Parameters

train_loader : torch.utils.data.DataLoader
each iterate is a training batch (X, y); train_loader.dataset needs to be set so that N, the size of the data set, can be accessed
override : bool, default=True
whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
def square_norm(self, value)

Compute the square norm under the posterior precision with \Delta = value - self.mean: \Delta^T P \Delta.

Returns

square_form
 
def log_prob(self, value, normalized=True)

Compute the log probability under the (current) Laplace approximation.

Parameters

normalized : bool, default=True
whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns

log_prob : torch.Tensor
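Both square_norm and log_prob are Gaussian computations in precision form. The sketch below uses hypothetical toy values for a dense posterior precision P and mean \theta_{MAP}; it evaluates \Delta^T P \Delta and the normalized log density -\frac{1}{2}(d \log 2\pi - \log \det P + \Delta^T P \Delta), with d the number of parameters, and assumes that normalized=False keeps only the value-dependent term -\frac{1}{2}\Delta^T P \Delta. The free functions are illustrative, not the class methods.

```python
import math
import torch
from torch.distributions import MultivariateNormal

def square_norm(value, mean, precision):
    # Delta^T P Delta with Delta = value - mean
    delta = value - mean
    return delta @ precision @ delta

def log_prob(value, mean, precision, normalized=True):
    # Only this term depends on value
    log_p = -0.5 * square_norm(value, mean, precision)
    if normalized:
        d = mean.numel()
        log_p = log_p - 0.5 * d * math.log(2 * math.pi) + 0.5 * torch.linalg.slogdet(precision).logabsdet
    return log_p

theta_map = torch.tensor([0.5, -1.0, 2.0])
P = torch.tensor([[2.0, 0.3, 0.0],
                  [0.3, 1.0, 0.2],
                  [0.0, 0.2, 1.5]])
value = torch.tensor([0.0, 0.0, 1.0])
print(log_prob(value, theta_map, P).item())
```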
 
def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters

prior_precision : torch.Tensor, optional
new prior precision, if it should differ from the current prior_precision value
sigma_noise : torch.Tensor or float, optional
new observation noise standard deviation, if it should be changed

Returns

log_marglik : torch.Tensor
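With a Gaussian prior, the Laplace log marginal likelihood decomposes into the terms documented for ParametricLaplace: \log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta_{MAP}) - \frac{1}{2}(\log \frac{\det P}{\det P_0} + \textrm{scatter}). The sketch below assumes a hypothetical Bayesian linear regression, an isotropic prior P_0 = p_0 I with \mu_0 = 0, and a fixed \theta_{MAP}. Keeping the result differentiable lets Adam tune \log p_0, which mirrors the idea behind method='marglik' without reproducing the library's optimization loop.

```python
import math
import torch

torch.manual_seed(0)
N, d, sigma = 50, 3, 1.0
X = torch.randn(N, d)
y = X @ torch.tensor([2.0, -1.0, 0.5]) + 0.1 * torch.randn(N)
theta_map = torch.linalg.solve(X.T @ X, X.T @ y)   # stand-in for a trained model
H = X.T @ X / sigma ** 2                            # log likelihood Hessian (exact here)

def log_marginal_likelihood(prior_prec):
    # log p(D | theta_MAP) for Gaussian noise with standard deviation sigma
    resid = y - X @ theta_map
    log_lik = -0.5 * (resid ** 2).sum() / sigma ** 2 - N * math.log(sigma * math.sqrt(2 * math.pi))
    P = H + prior_prec * torch.eye(d)               # posterior precision
    log_det_ratio = torch.linalg.slogdet(P).logabsdet - d * torch.log(prior_prec)
    scatter = prior_prec * (theta_map ** 2).sum()   # prior mean mu_0 = 0
    return log_lik - 0.5 * (log_det_ratio + scatter)

# Optimize the log prior precision so the precision stays positive
log_prior_prec = torch.zeros((), requires_grad=True)   # initial prior precision 1.0
initial = log_marginal_likelihood(log_prior_prec.exp()).item()
optimizer = torch.optim.Adam([log_prior_prec], lr=0.1)
for _ in range(100):
    optimizer.zero_grad()
    neg_marglik = -log_marginal_likelihood(log_prior_prec.exp())
    neg_marglik.backward()
    optimizer.step()

final = log_marginal_likelihood(log_prior_prec.exp()).item()
print(log_prior_prec.exp().item(), initial, final)
```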
 
def predictive_samples(self, x, pred_type='glm', n_samples=100, diagonal_output=False, generator=None)

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters

x : torch.Tensor
input data (batch_size, input_shape)
pred_type : {'glm', 'nn'}, default='glm'
type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
n_samples : int, default=100
number of samples
diagonal_output : bool, default=False
whether to use a diagonalized GLM posterior predictive on the outputs. Only applies when pred_type='glm'.
generator : torch.Generator, optional
random number generator to control the samples (if sampling is used)

Returns

samples : torch.Tensor
samples (n_samples, batch_size, output_shape)
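For pred_type='glm', sampling can be sketched as drawing f \sim \mathcal{N}(f_\mu, f_{var}) per input via f = f_\mu + L\epsilon, where L L^T = f_{var} is a Cholesky factor. The function glm_predictive_samples and its f_mu/f_var arguments (linearized mean and functional variance) are hypothetical; passing a torch.Generator makes the draws reproducible, and diagonal_output=True keeps only the marginal variances.

```python
import torch

def glm_predictive_samples(f_mu, f_var, n_samples=100, diagonal_output=False, generator=None):
    # f_mu: (batch, outputs), f_var: (batch, outputs, outputs)
    eps = torch.randn(n_samples, *f_mu.shape, generator=generator, dtype=f_mu.dtype)
    if diagonal_output:
        # independent outputs: scale each output by its marginal standard deviation
        return f_mu + eps * torch.diagonal(f_var, dim1=-2, dim2=-1).sqrt()
    L = torch.linalg.cholesky(f_var)                     # (batch, outputs, outputs)
    return f_mu + (L @ eps.unsqueeze(-1)).squeeze(-1)    # (n_samples, batch, outputs)

f_mu = torch.tensor([[0.0, 1.0], [2.0, -1.0]])
f_var = torch.tensor([[[1.0, 0.5], [0.5, 2.0]],
                      [[0.5, 0.0], [0.0, 0.5]]])
samples = glm_predictive_samples(f_mu, f_var, n_samples=1000, generator=torch.Generator().manual_seed(0))
print(samples.shape)
```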
def functional_variance(self, Jacs)

Compute the functional variance for the 'glm' predictive: f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, an (outputs x outputs) predictive covariance matrix. Mathematically, for a single Jacobian \mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}, the output covariance matrix is \mathcal{J} P^{-1} \mathcal{J}^T.

Parameters

Jacs : torch.Tensor
Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns

f_var : torch.Tensor
output covariance (batch, outputs, outputs)
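The batched contraction can be sketched with torch.einsum, assuming the posterior covariance \Sigma = P^{-1} is available densely or as its diagonal; the helper names are hypothetical. In the diagonal case the (parameters x parameters) matrix is never formed.

```python
import torch

def functional_variance_full(Jacs, posterior_covariance):
    # f_var[b] = J_b Sigma J_b^T for each input in the batch
    return torch.einsum('bop,pq,bkq->bok', Jacs, posterior_covariance, Jacs)

def functional_variance_diag(Jacs, posterior_variance):
    # same contraction when Sigma = diag(posterior_variance), without forming Sigma
    return torch.einsum('bop,p,bkp->bok', Jacs, posterior_variance, Jacs)

torch.manual_seed(0)
batch, outputs, params = 4, 3, 5
Jacs = torch.randn(batch, outputs, params)
A = torch.randn(params, params)
P = A @ A.T + torch.eye(params)          # toy dense posterior precision
Sigma = torch.linalg.inv(P)
f_var = functional_variance_full(Jacs, Sigma)
print(f_var.shape)
```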
def sample(self, n_samples=100)

Sample from the Laplace posterior approximation, i.e., \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).

Parameters

n_samples : int, default=100
number of samples
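Sampling \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}) can be sketched without inverting P: with the Cholesky factorization P = L L^T, the draw \theta = \theta_{MAP} + L^{-T}\epsilon has covariance (L L^T)^{-1} = P^{-1}. For a diagonal precision p this reduces to \epsilon / \sqrt{p}. Both helpers are hypothetical illustrations for the dense and diagonal cases.

```python
import torch

def sample_full(theta_map, precision, n_samples=100, generator=None):
    # theta = theta_MAP + L^{-T} eps with P = L L^T, so Cov(theta) = P^{-1}
    L = torch.linalg.cholesky(precision)
    eps = torch.randn(theta_map.numel(), n_samples, generator=generator)
    z = torch.linalg.solve_triangular(L.T, eps, upper=True)   # (params, n_samples)
    return theta_map + z.T                                      # (n_samples, params)

def sample_diag(theta_map, precision_diag, n_samples=100, generator=None):
    # independent parameters: divide standard normal draws by sqrt(p)
    eps = torch.randn(n_samples, theta_map.numel(), generator=generator)
    return theta_map + eps / precision_diag.sqrt()

theta_map = torch.tensor([1.0, -1.0])
P = torch.tensor([[2.0, 0.8], [0.8, 1.0]])
gen = torch.Generator().manual_seed(0)
samples = sample_full(theta_map, P, n_samples=100000, generator=gen)
print(samples.shape)
```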
def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False)

Inherited members

class FullLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Laplace approximation with a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Depending on the chosen backend, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{d \times d}, where d is the number of parameters. See BaseLaplace for the full interface.

Ancestors

Subclasses

Instance variables

var posterior_scale

Posterior scale (square root of the covariance), i.e., P^{-\frac{1}{2}}.

Returns

scale : torch.Tensor
(parameters, parameters)
var posterior_covariance

Posterior covariance, i.e., P^{-1}.

Returns

covariance : torch.Tensor
(parameters, parameters)
var posterior_precision

Posterior precision P.

Returns

precision : torch.Tensor
(parameters, parameters)

Inherited members

class KronLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, damping=False, **backend_kwargs)

Laplace approximation with a Kronecker-factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., torch.nn.Module, that P \approx Q \otimes H. See BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing them up, and KronDecomposed is used to add the prior and a Hessian factor (e.g., temperature) and to compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by setting damping=True.
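Keeping Q and H separate pays off because eigendecompositions of the small factors give those of the large product: if Q has eigenvalues q_i and H has eigenvalues h_j, then Q \otimes H + \delta I has eigenvalues q_i h_j + \delta. The sketch below assumes a single Kronecker-factored block with toy factors and an isotropic prior precision \delta, computes \log \det without forming the Kronecker product, and is illustrative of what KronDecomposed enables rather than its interface.

```python
import torch

torch.manual_seed(0)
A, B = torch.randn(3, 3), torch.randn(4, 4)
Q = A @ A.T          # toy 3 x 3 Kronecker factor
H = B @ B.T          # toy 4 x 4 Kronecker factor
delta = 0.5          # isotropic prior precision added to the block

# Eigendecompose only the small factors
q = torch.linalg.eigvalsh(Q)
h = torch.linalg.eigvalsh(H)

# Eigenvalues of Q (x) H + delta*I are all products q_i * h_j shifted by delta
kron_eigs = torch.outer(q, h).flatten() + delta
log_det = kron_eigs.log().sum()
print(log_det.item())
```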

Ancestors

Subclasses

Instance variables

var posterior_precision

Kronecker-factored posterior precision P.

Returns

precision : KronDecomposed
 
var prior_precision

Inherited members

class DiagLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See BaseLaplace for the full interface.

Ancestors

Subclasses

Instance variables

var posterior_precision

Diagonal posterior precision p.

Returns

precision : torch.Tensor
(parameters)
var posterior_scale

Diagonal posterior scale \sqrt{p^{-1}}.

Returns

scale : torch.Tensor
(parameters)
var posterior_variance

Diagonal posterior variance p^{-1}.

Returns

variance : torch.Tensor
(parameters)
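For DiagLaplace all three quantities are elementwise. The sketch below assumes, as a hypothetical convention for illustration, that the accumulated diagonal curvature H is scaled by 1/(\sigma^2 T) before the diagonal prior precision p_0 is added, consistent with the BaseLaplace statement that a lower temperature concentrates the posterior.

```python
import torch

H_diag = torch.tensor([4.0, 0.0, 10.0])     # accumulated diagonal curvature (toy values)
p0 = torch.tensor([1.0, 1.0, 2.0])          # diagonal prior precision
sigma_noise, temperature = 2.0, 0.5

# p = H / (sigma^2 * T) + p0; a lower temperature increases the precision
posterior_precision = H_diag / (sigma_noise ** 2 * temperature) + p0
posterior_variance = posterior_precision.reciprocal()   # p^{-1}
posterior_scale = posterior_variance.sqrt()             # sqrt(p^{-1})
print(posterior_precision, posterior_variance, posterior_scale)
```

A parameter with zero curvature, like the second one here, falls back to its prior variance 1 / p_0.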

Inherited members

class LowRankLaplace (model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, temperature=1, backend=laplace.curvature.asdl.AsdlHessian, backend_kwargs=None)

Laplace approximation with a low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Depending on the chosen backend, either the true Hessian or, for example, a GGN approximation can be used. The posterior precision is computed as P = V diag(l) V^T + P_0. To sample and to compute the functional variance and the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a K \times K matrix for rank K.

See BaseLaplace for the full interface.
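The K x K reductions can be sketched with the matrix determinant lemma and the Woodbury identity for P_0 = \textrm{diag}(p_0): \log \det P = \log \det P_0 + \log \det(I_K + \textrm{diag}(l) V^T P_0^{-1} V) and P^{-1} = P_0^{-1} - P_0^{-1} V (\textrm{diag}(l)^{-1} + V^T P_0^{-1} V)^{-1} V^T P_0^{-1}. The toy sizes are hypothetical, the eigenvalues l are assumed strictly positive, and both results are checked against dense computations.

```python
import torch

torch.manual_seed(0)
d, K = 50, 3
V, _ = torch.linalg.qr(torch.randn(d, K))     # orthonormal eigenvectors (d x K)
l = torch.tensor([5.0, 2.0, 0.5])             # positive eigenvalues
p0 = torch.rand(d) + 0.5                      # diagonal prior precision

Vs = V / p0.unsqueeze(-1)                     # P_0^{-1} V
small = torch.diag(1 / l) + V.T @ Vs          # K x K matrix, the only one solved against

# log det P via the matrix determinant lemma
log_det = p0.log().sum() + torch.linalg.slogdet(torch.eye(K) + torch.diag(l) @ V.T @ Vs).logabsdet

# P^{-1} via the Woodbury identity
cov = torch.diag(1 / p0) - Vs @ torch.linalg.solve(small, Vs.T)
print(log_det.item())
```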

Ancestors

Instance variables

var V
var Kinv
var posterior_precision

Return the correctly scaled posterior precision, which would be constructed as H[0] @ diag(H[1]) @ H[0].T + diag(self.prior_precision_diag).

Returns

H : tuple(eigenvectors, eigenvalues)
scaled self.H with temperature and loss factors.
prior_precision_diag : torch.Tensor
diagonal prior precision, of shape (parameters), to be added to H.

Inherited members