Module laplace.subnetlaplace

Classes

class SubnetLaplace (model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Class for subnetwork Laplace, which computes the Laplace approximation over just a subset of the model parameters (i.e. a subnetwork within the neural network), as proposed in [1]. Subnetwork Laplace can only be used with either a full or a diagonal Hessian approximation.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, and a posterior precision or covariance, which together specify a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only a subset of the model parameters (i.e. a subnetwork of the neural network) is treated probabilistically. The goal of this class is to compute the posterior precision P, which is the sum P = -\sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} - \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. The prior is assumed to be Gaussian, so its term has the simple form -\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
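
As a small worked illustration of how the structure of p_0 can vary, the two cases named above can be written out as follows. The symbols D (the number of parameters treated probabilistically) and \lambda (the prior precision values) are notation assumed here for illustration and do not appear in the library documentation.

```latex
% Requires amsmath. D and \lambda are assumed notation, not library names.
\begin{align*}
\text{scalar } \lambda:\quad
  & p_0 = \lambda \mathbf{1}_D,
  & P_0 &= \textrm{diag}(p_0) = \lambda I_D, \\
\text{diagonal } (\lambda_1, \dots, \lambda_D):\quad
  & p_0 = (\lambda_1, \dots, \lambda_D)^\top,
  & P_0 &= \textrm{diag}(\lambda_1, \dots, \lambda_D).
\end{align*}
```

In both cases P_0 is diagonal, so adding it to the log likelihood term only changes the diagonal entries of P.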

In its full variant, the subnetwork Laplace approximation uses a full, i.e., dense, log likelihood Hessian approximation and hence a dense posterior precision. Depending on the chosen backend, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace and BaseLaplace for the full interface.

References

[1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. Bayesian Deep Learning via Subnetwork Inference. ICML 2021.

Parameters

model : torch.nn.Module or FeatureExtractor
 
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
subnetwork_indices : torch.LongTensor
indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa.
backend : subclasses of CurvatureInterface
backend for access to curvature/Hessian approximations
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
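
To make the role of subnetwork_indices concrete, the following is a minimal sketch of how such an index tensor might be built with plain PyTorch. The toy model and the selection rule (keep the k largest-magnitude parameters) are assumptions made for illustration; they are not prescribed by this module, and any rule that yields valid indices into the vectorized parameters could be used instead.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# Flatten all parameters in the order given by model.parameters();
# subnetwork_indices index into this vector.
theta = parameters_to_vector(model.parameters()).detach()
n_params = theta.numel()  # 4*8 + 8 + 8*2 + 2 = 58

# Assumed selection rule: keep the k parameters with the largest magnitude.
k = 10
subnetwork_indices = torch.topk(theta.abs(), k).indices.sort().values

# The result is a 1-D integer (long) tensor of positions in theta.
print(subnetwork_indices.dtype, subnetwork_indices.shape)
```

A tensor built this way could then be passed as the subnetwork_indices argument in the constructor signature shown above, together with the model and the likelihood.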

Ancestors

Subclasses

Instance variables

var prior_precision_diag

Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
 
var mean_subnet

Methods

def assemble_full_samples(self, subnet_samples)

Inherited members

class FullSubnetLaplace (model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, SubnetLaplace, and BaseLaplace for the full interface.

Ancestors

Inherited members

class DiagSubnetLaplace (model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=None, backend_kwargs=None)

Subnetwork Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, SubnetLaplace, and BaseLaplace for the full interface.
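
The difference between the dense and the diagonal posterior precision can be sketched numerically. The example below does not call the library; it uses a random positive semi-definite matrix as a stand-in for a curvature approximation over a hypothetical subnetwork of size k, with a scalar prior precision of 1.0. All variable names are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

k = 5  # hypothetical number of subnetwork parameters

# Stand-in for a dense, positive semi-definite curvature matrix
# (for example, a generalized Gauss-Newton matrix) over the subnetwork.
J = torch.randn(20, k)
H = J.T @ J

prior_precision = 1.0  # scalar prior precision, so P_0 = 1.0 * I

# Dense posterior precision: k * k entries.
P_full = H + prior_precision * torch.eye(k)

# Diagonal approximation: keep only the k diagonal entries.
p_diag = torch.diagonal(H) + prior_precision

print(P_full.numel(), p_diag.numel())
```

The diagonal variant stores k values instead of k * k, at the cost of discarding all correlations between subnetwork parameters.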

Ancestors

Inherited members