Module laplace.subnetlaplace

Classes

class SubnetLaplace (model: nn.Module, likelihood: Likelihood | str, subnetwork_indices: torch.LongTensor, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Class for subnetwork Laplace, which computes the Laplace approximation over just a subset of the model parameters (i.e. a subnetwork within the neural network), as proposed in [1]. Subnetwork Laplace can only be used with either a full or a diagonal Hessian approximation.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, and a posterior precision or covariance, which together specify a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only a subset of the model parameters (i.e. a subnetwork of the neural network) is treated probabilistically. The goal of this class is to compute the posterior precision P, which is the negative Hessian of the log posterior and decomposes as P = -\sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} - \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. The prior is assumed to be Gaussian, so the prior term has the simple form -\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

The subnetwork Laplace approximation supports either a full, i.e., dense, or a diagonal log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, in the full case we have P \in \mathbb{R}^{P \times P}. See FullLaplace and BaseLaplace for the full interface.
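The sum of a likelihood curvature term and a diagonal prior term can be illustrated on a toy problem. The following is a minimal sketch, not the library's implementation: it assumes a linear regression model whose Jacobian is the design matrix, so the generalized Gauss-Newton matrix is exact, and all names (X, J_subnet, H_lik, p_0) are chosen for illustration only.

```python
import torch

torch.manual_seed(0)

# Toy linear regression f(x) = x @ theta with Gaussian observation noise.
n_data, n_params = 20, 5
X = torch.randn(n_data, n_params, dtype=torch.float64)
sigma_noise = 0.5
prior_precision = 2.0  # scalar prior precision, so P_0 = prior_precision * I

# Hypothetical subnetwork: entries 0, 2 and 3 of the vectorized parameters.
subnetwork_indices = torch.tensor([0, 2, 3], dtype=torch.long)

# For this model the Jacobian of f w.r.t. theta is X itself; the subnetwork
# Jacobian keeps only the selected columns.
J_subnet = X[:, subnetwork_indices]

# Curvature of the negative log likelihood over the subnetwork
# (exact for this model, a GGN approximation in general).
H_lik = J_subnet.T @ J_subnet / sigma_noise**2

# Diagonal prior precision p_0, restricted to the subnetwork.
p_0 = prior_precision * torch.ones(len(subnetwork_indices), dtype=torch.float64)

# Posterior precision and covariance over the subnetwork parameters only.
P = H_lik + torch.diag(p_0)
posterior_cov = torch.linalg.inv(P)
```

The resulting matrix has one row and column per subnetwork parameter rather than per model parameter, which is what keeps a dense approximation affordable.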

References

[1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. Bayesian Deep Learning via Subnetwork Inference. ICML 2021.

Parameters

model : torch.nn.Module or FeatureExtractor
 
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
subnetwork_indices : torch.LongTensor
indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.
backend : subclasses of CurvatureInterface
backend for access to curvature/Hessian approximations
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
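
Since subnetwork_indices refers to positions in torch.nn.utils.parameters_to_vector(model.parameters()), it can be built with plain tensor operations. The sketch below shows two possible selection rules on a small hypothetical model; both rules and all variable names are assumptions made for illustration, not recommendations from this module.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Small hypothetical model: 4*8 + 8 + 8*2 + 2 = 58 parameters.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# Vectorized parameters, in the order returned by model.parameters().
theta = parameters_to_vector(model.parameters()).detach()
n_params = theta.numel()

# Rule 1: the n_params_subnet entries with the largest absolute value.
n_params_subnet = 10
largest_magnitude_indices = torch.topk(theta.abs(), n_params_subnet).indices

# Rule 2: all parameters of the last layer, which occupy the tail of theta.
n_last_layer = sum(p.numel() for p in model[-1].parameters())
last_layer_indices = torch.arange(n_params - n_last_layer, n_params)
```

Both results are one-dimensional torch.LongTensor objects with values in [0, n_params), which matches the type annotated for subnetwork_indices.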

Instance variables

var prior_precision_diag : torch.Tensor

Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
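
As an illustration of what constructing p_0 from a scalar or a diagonal prior precision can look like, here is a minimal sketch. The helper expand_prior_precision and its argument n_params_subnet are hypothetical names, and the sketch assumes p_0 has one entry per subnetwork parameter; the property itself may handle further cases.

```python
import torch


def expand_prior_precision(prior_precision, n_params_subnet):
    """Hypothetical helper: return p_0 as a vector of length n_params_subnet."""
    prior_precision = torch.as_tensor(prior_precision, dtype=torch.float64).reshape(-1)
    if prior_precision.numel() == 1:
        # Scalar prior precision: same value for every subnetwork parameter.
        return prior_precision * torch.ones(n_params_subnet, dtype=torch.float64)
    if prior_precision.numel() == n_params_subnet:
        # Diagonal prior precision: one value per subnetwork parameter.
        return prior_precision
    raise ValueError("prior precision must be scalar or match the subnetwork size")


p_0_scalar = expand_prior_precision(2.0, n_params_subnet=3)
p_0_diag = expand_prior_precision(torch.tensor([1.0, 2.0, 3.0]), n_params_subnet=3)
```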
 
var mean_subnet : torch.Tensor

Methods

def assemble_full_samples(self, subnet_samples) ‑> torch.Tensor
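
This method carries no description here. Judging by its name and signature, one plausible reading is that it embeds samples of the subnetwork parameters into full parameter vectors, keeping every other parameter fixed at its MAP value. The standalone function below sketches that idea under this assumption; assemble_full_samples_sketch and its arguments are illustrative names, not part of the documented interface.

```python
import torch


def assemble_full_samples_sketch(mean, subnetwork_indices, subnet_samples):
    """Embed subnetwork samples into copies of the full MAP parameter vector."""
    # One copy of the MAP vector per sample: shape (n_samples, n_params).
    full_samples = mean.repeat(subnet_samples.shape[0], 1)
    # Overwrite only the subnetwork entries with the sampled values.
    full_samples[:, subnetwork_indices] = subnet_samples
    return full_samples


mean = torch.arange(6, dtype=torch.float32)  # stand-in for the MAP estimate
subnetwork_indices = torch.tensor([1, 4])
subnet_samples = torch.tensor([[10.0, 40.0], [11.0, 41.0], [12.0, 42.0]])

full_samples = assemble_full_samples_sketch(mean, subnetwork_indices, subnet_samples)
```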

class FullSubnetLaplace (model: nn.Module, likelihood: Likelihood | str, subnetwork_indices: torch.LongTensor, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, SubnetLaplace, and BaseLaplace for the full interface.

class DiagSubnetLaplace (model: nn.Module, likelihood: Likelihood | str, subnetwork_indices: torch.LongTensor, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Subnetwork Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, SubnetLaplace, and BaseLaplace for the full interface.
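
To make the difference between the dense and the diagonal posterior precision concrete, the following sketch builds both from the same toy curvature. It assumes a random Jacobian J over a four-parameter subnetwork and a unit prior precision; these values and names are illustrative and do not come from the library.

```python
import torch

torch.manual_seed(0)

# Hypothetical Jacobian of 30 outputs w.r.t. 4 subnetwork parameters.
n_subnet = 4
J = torch.randn(30, n_subnet, dtype=torch.float64)
p_0 = torch.ones(n_subnet, dtype=torch.float64)

# Dense posterior precision: n_subnet x n_subnet entries.
P_full = J.T @ J + torch.diag(p_0)

# Diagonal approximation: keep only the n_subnet diagonal entries.
P_diag = torch.diagonal(P_full).clone()

# Marginal variances under each approximation.
var_full = torch.diagonal(torch.linalg.inv(P_full))
var_diag = 1.0 / P_diag
```

The diagonal variant stores and inverts a vector instead of a matrix, at the cost of discarding correlations between parameters. For a positive definite precision, the resulting marginal variances 1 / P_ii are never larger than those of the dense approximation.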
