laplace.subnetlaplace #

SubnetLaplace #

SubnetLaplace(model: Module, likelihood: Likelihood | str, subnetwork_indices: LongTensor, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Bases: ParametricLaplace

Class for subnetwork Laplace, which computes the Laplace approximation over just a subset of the model parameters (i.e. a subnetwork within the neural network), as proposed in [1]. Subnetwork Laplace can only be used with either a full or a diagonal Hessian approximation.

A Laplace approximation is represented by a MAP which is given by the model parameter and a posterior precision or covariance specifying a Gaussian distribution \(\mathcal{N}(\theta_{MAP}, P^{-1})\). Here, only a subset of the model parameters (i.e. a subnetwork of the neural network) is treated probabilistically. The goal of this class is to compute the posterior precision \(P\), which is given by the sum

\[ P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. \]

The prior is assumed to be Gaussian and therefore we have a simple form for \(\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 \). In particular, we assume a scalar or diagonal prior precision so that in all cases \(P_0 = \textrm{diag}(p_0)\) and the structure of \(p_0\) can be varied.

The subnetwork Laplace approximation supports a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision (see DiagSubnetLaplace below for the diagonal variant). Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have \(P \in \mathbb{R}^{P_{\textrm{subnet}} \times P_{\textrm{subnet}}}\), where \(P_{\textrm{subnet}}\) is the number of subnetwork parameters. See FullLaplace and BaseLaplace for the full interface.

References

[1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. Bayesian Deep Learning via Subnetwork Inference. ICML 2021.

Parameters:

  • model (torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`) –
  • likelihood (('classification', 'regression')) –

    determines the log likelihood Hessian approximation

  • subnetwork_indices (LongTensor) –

    indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork to apply the Laplace approximation over

  • sigma_noise (Tensor or float, default: 1 ) –

    observation noise for the regression setting; must be 1 for classification

  • prior_precision (Tensor or float, default: 1 ) –

    prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case

  • prior_mean (Tensor or float, default: 0 ) –

    prior mean of a Gaussian prior, useful for continual learning

  • temperature (float, default: 1 ) –

    temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa.

  • backend (subclasses of `laplace.curvature.{GGNInterface,EFInterface}`, default: None ) –

    backend for access to curvature/Hessian approximations

  • backend_kwargs (dict, default: None ) –

    arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

Source code in laplace/subnetlaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    subnetwork_indices: torch.LongTensor,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    backend: Type[CurvatureInterface] | None = None,
    backend_kwargs: dict | None = None,
    asdl_fisher_kwargs: dict | None = None,
) -> None:
    if asdl_fisher_kwargs is not None:
        raise ValueError("Subnetwork Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=prior_precision,
        prior_mean=prior_mean,
        temperature=temperature,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )

    if backend is not None and not issubclass(backend, (GGNInterface, EFInterface)):
        raise ValueError("SubnetLaplace can only be used with GGN and EF.")

    # check validity of subnetwork indices and pass them to backend
    self._check_subnetwork_indices(subnetwork_indices)
    self.backend.subnetwork_indices = subnetwork_indices
    self.n_params_subnet = len(subnetwork_indices)
    self._init_H()
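
For orientation, here is a minimal construction sketch. The toy model, its dimensions, and the largest-magnitude index selection are hypothetical choices for illustration; any valid subnetwork_indices (a non-empty 1-D LongTensor of unique, in-range indices) works. DiagSubnetLaplace, documented further down on this page and sharing this signature, is used for the instantiation.

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

from laplace.subnetlaplace import DiagSubnetLaplace

# hypothetical toy model: 4 inputs, 3 classes
model = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 3))

# pick the 64 largest-magnitude entries of the vectorized parameters as the subnetwork
param_vector = parameters_to_vector(model.parameters()).detach()
subnetwork_indices = torch.argsort(param_vector.abs(), descending=True)[:64]

la = DiagSubnetLaplace(
    model,
    likelihood="classification",
    subnetwork_indices=subnetwork_indices,
    prior_precision=1.0,
)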

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_posterior_precision #

log_det_posterior_precision: Tensor

Compute log determinant of the posterior precision \(\log \det P\), which depends on the Hessian structure used by the subclass.

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Compute or return the posterior precision \(P\).

Returns:

  • posterior_prec ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), the size of the dataset.

  • override (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar (bool, default: False ) –

    whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, esp. when subset_of_weights='all'.

Source code in laplace/baselaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar : bool, default=False
        whether to show a progress bar; updated at every batch-Hessian computation.
        Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`.
    """
    if override:
        self._init_H()
        self.loss: float | torch.Tensor = 0
        self.n_data: int = 0

    self.model.eval()

    self.mean: torch.Tensor = parameters_to_vector(self.params)
    if not self.enable_backprop:
        self.mean = self.mean.detach()

    data: (
        tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
    ) = next(iter(train_loader))

    with torch.no_grad():
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            if "backpack" in self._backend_cls.__name__.lower() or (
                isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
            ):
                raise ValueError(
                    "Currently DiagEF is not supported under CurvlinopsEF backend "
                    + "for custom models with non-tensor inputs "
                    + "(https://github.com/pytorch/functorch/issues/159). Consider "
                    + "using AsdlEF backend instead. The same limitation applies "
                    + "to all BackPACK backend"
                )

            out = self.model(data)
        else:
            X = data[0]
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                out = self.model(X.to(self._device))
    self.n_outputs = out.shape[-1]
    setattr(self.model, "output_size", self.n_outputs)

    N = len(train_loader.dataset)

    pbar = tqdm.tqdm(train_loader, disable=not progress_bar)
    pbar.set_description("[Computing Hessian]")

    for data in pbar:
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            X, y = data, data[self.dict_key_y].to(self._device)
        else:
            X, y = data
            X, y = X.to(self._device), y.to(self._device)

        if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
            raise ValueError(
                f"The model's output has {out.ndim} dims but "
                f"the target has {y.ndim} dims."
            )

        self.model.zero_grad()
        loss_batch, H_batch = self._curv_closure(X, y, N=N)
        self.loss += loss_batch
        self.H += H_batch

    self.n_data += N
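
Continuing the hypothetical setup from the construction sketch above, a fitting sketch with toy data (train_loader.dataset must be set so that \(N\) can be read off):

import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.randn(256, 4)                 # toy inputs matching the sketch model
y = torch.randint(0, 3, (256,))         # toy class labels
train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

la.fit(train_loader)                    # override=True (default): H, loss, n_data start fresh
la.fit(train_loader, override=False)    # accumulate onto the existing approximation (online setting)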

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
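
Because the returned tensor is differentiable in prior_precision, a common post-hoc use is to tune the prior precision by gradient-based maximization of the marginal likelihood. A sketch with the hypothetical la from above (the optimizer, learning rate, and number of steps are illustrative choices, not prescribed by the library):

import torch

log_prior_prec = torch.zeros(1, requires_grad=True)      # optimize in log space to keep it positive
optimizer = torch.optim.Adam([log_prior_prec], lr=1e-1)

for _ in range(100):
    optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
    neg_marglik.backward()
    optimizer.step()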

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on subset of parameters (i.e. some grad are disabled), only nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)
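
A prediction sketch with the same hypothetical la (classification, so a single tensor of class probabilities is returned; shapes refer to the toy model above):

import torch

x_test = torch.randn(8, 4)
probs = la(x_test)          # GLM predictive with the probit link (the defaults)
print(probs.shape)          # torch.Size([8, 3]); each row is a distribution over classes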

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

square_norm #

square_norm(value) -> Tensor

Compute the square norm under the posterior precision with value - self.mean as \(\Delta\):

\[ \Delta^\top P \Delta \]

Returns:

  • square_form
Source code in laplace/baselaplace.py
def square_norm(self, value) -> torch.Tensor:
    """Compute the square norm under post. Precision with `value-self.mean` as 𝛥:

    $$
        \\Delta^\top P \\Delta
    $$

    Returns
    -------
    square_form
    """
    raise NotImplementedError
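
A dense reference for the quadratic form above (illustration only; concrete subclasses implement it for their particular posterior structure):

import torch

def square_norm_dense(value: torch.Tensor, mean: torch.Tensor, posterior_precision: torch.Tensor) -> torch.Tensor:
    # Delta^T P Delta with Delta = value - mean and a dense precision matrix P
    delta = value - mean
    return delta @ posterior_precision @ delta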

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)
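
A Thompson-sampling-flavored sketch with the hypothetical la and toy shapes from above:

import torch

x_test = torch.randn(8, 4)
samples = la.predictive_samples(x_test, pred_type="glm", n_samples=50)
# samples: (50, 8, 3) class probabilities per draw; act greedily per draw and input
actions = samples.argmax(dim=-1)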

functional_variance #

functional_variance(Js: Tensor) -> Tensor

Compute functional variance for the 'glm' predictive: f_var[i] = Js[i] @ P.inv() @ Js[i].T, which is an output x output predictive covariance matrix. Mathematically, we have for a single Jacobian \(\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}\) the output covariance matrix \( \mathcal{J} P^{-1} \mathcal{J}^T \).

Parameters:

  • Js (Tensor) –

    Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns:

  • f_var ( Tensor ) –

    output covariance (batch, outputs, outputs)

Source code in laplace/baselaplace.py
def functional_variance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional variance for the `'glm'` predictive:
    `f_var[i] = Js[i] @ P.inv() @ Js[i].T`, which is a output x output
    predictive covariance matrix.
    Mathematically, we have for a single Jacobian
    \\(\\mathcal{J} = \\nabla_\\theta f(x;\\theta)\\vert_{\\theta_{MAP}}\\)
    the output covariance matrix
    \\( \\mathcal{J} P^{-1} \\mathcal{J}^T \\).

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch, outputs, parameters)`

    Returns
    -------
    f_var : torch.Tensor
        output covariance `(batch, outputs, outputs)`
    """
    raise NotImplementedError
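
A dense reference implementation of the formula above (illustration only; subclasses implement this efficiently for their Hessian structure):

import torch

def functional_variance_dense(Js: torch.Tensor, posterior_cov: torch.Tensor) -> torch.Tensor:
    # Js: (batch, outputs, parameters); posterior_cov = P^{-1}: (parameters, parameters)
    # returns (batch, outputs, outputs) with f_var[i] = Js[i] @ P^{-1} @ Js[i].T
    return torch.einsum("bop,pq,bnq->bon", Js, posterior_cov, Js)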

functional_covariance #

functional_covariance(Js: Tensor) -> Tensor

Compute functional covariance for the 'glm' predictive: f_cov = Js @ P.inv() @ Js.T, which is a batch*output x batch*output predictive covariance matrix.

This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization.

Parameters:

  • Js (Tensor) –

    Jacobians of model output wrt parameters (batch*outputs, parameters)

Returns:

  • f_cov ( Tensor ) –

    output covariance (batch*outputs, batch*outputs)

Source code in laplace/baselaplace.py
def functional_covariance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional covariance for the `'glm'` predictive:
    `f_cov = Js @ P.inv() @ Js.T`, which is a batch*output x batch*output
    predictive covariance matrix.

    This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
    Useful for joint predictions, such as in batched Bayesian optimization.

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch*outputs, parameters)`

    Returns
    -------
    f_cov : torch.Tensor
        output covariance `(batch*outputs, batch*outputs)`
    """
    raise NotImplementedError
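
The corresponding dense reference for the joint covariance (illustration only):

import torch

def functional_covariance_dense(Js: torch.Tensor, posterior_cov: torch.Tensor) -> torch.Tensor:
    # Js: (batch*outputs, parameters) -> (batch*outputs, batch*outputs)
    return Js @ posterior_cov @ Js.T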

sample #

sample(n_samples: int = 100, generator: Generator | None = None) -> Tensor

Sample from the Laplace posterior approximation, i.e., \( \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})\).

Parameters:

  • n_samples (int, default: 100 ) –

    number of samples

  • generator (Generator, default: None ) –

    random number generator to control the samples

Returns:

  • samples ( Tensor ) –
Source code in laplace/baselaplace.py
def sample(
    self, n_samples: int = 100, generator: torch.Generator | None = None
) -> torch.Tensor:
    """Sample from the Laplace posterior approximation, i.e.,
    \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).

    Parameters
    ----------
    n_samples : int, default=100
        number of samples

    generator : torch.Generator, optional
        random number generator to control the samples

    Returns
    -------
    samples: torch.Tensor
    """
    raise NotImplementedError
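
For a diagonal posterior, sampling reduces to scaling standard normal noise. This is a sketch of the math, not the library's implementation; mean and scale stand in for the subnetwork MAP estimate and the diagonal posterior scale:

import torch

def sample_diag(mean: torch.Tensor, scale: torch.Tensor, n_samples: int) -> torch.Tensor:
    eps = torch.randn(n_samples, mean.numel())
    return mean + scale * eps       # theta_s = theta_MAP + sqrt(p^{-1}) * eps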

_check_subnetwork_indices #

_check_subnetwork_indices(subnetwork_indices: LongTensor | None) -> None

Check that subnetwork indices are valid indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())).

Source code in laplace/subnetlaplace.py
def _check_subnetwork_indices(
    self, subnetwork_indices: torch.LongTensor | None
) -> None:
    """Check that subnetwork indices are valid indices of the vectorized model parameters
    (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`).
    """
    if subnetwork_indices is None:
        raise ValueError("Subnetwork indices cannot be None.")
    elif not (
        isinstance(subnetwork_indices, torch.LongTensor)
        and subnetwork_indices.numel() > 0
        and len(subnetwork_indices.shape) == 1
    ):
        raise ValueError(
            "Subnetwork indices must be non-empty 1-dimensional torch.LongTensor."
        )
    elif not (
        len(subnetwork_indices[subnetwork_indices < 0]) == 0
        and len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0
    ):
        raise ValueError(
            f"Subnetwork indices must lie between 0 and n_params={self.n_params}."
        )
    elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)):
        raise ValueError("Subnetwork indices must not contain duplicate entries.")

DiagSubnetLaplace #

DiagSubnetLaplace(model: Module, likelihood: Likelihood | str, subnetwork_indices: LongTensor, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Bases: SubnetLaplace, DiagLaplace

Subnetwork Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have \(P \approx \textrm{diag}(P)\). See DiagLaplace, SubnetLaplace, and BaseLaplace for the full interface.

Source code in laplace/subnetlaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    subnetwork_indices: torch.LongTensor,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    backend: Type[CurvatureInterface] | None = None,
    backend_kwargs: dict | None = None,
    asdl_fisher_kwargs: dict | None = None,
) -> None:
    if asdl_fisher_kwargs is not None:
        raise ValueError("Subnetwork Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=prior_precision,
        prior_mean=prior_mean,
        temperature=temperature,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )

    if backend is not None and not issubclass(backend, (GGNInterface, EFInterface)):
        raise ValueError("SubnetLaplace can only be used with GGN and EF.")

    # check validity of subnetwork indices and pass them to backend
    self._check_subnetwork_indices(subnetwork_indices)
    self.backend.subnetwork_indices = subnetwork_indices
    self.n_params_subnet = len(subnetwork_indices)
    self._init_H()

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Diagonal posterior precision \(p\).

Returns:

  • precision ( Tensor ) –

    (parameters)

posterior_scale #

posterior_scale: Tensor

Diagonal posterior scale \(\sqrt{p^{-1}}\).

Returns:

  • scale ( Tensor ) –

    (parameters)

posterior_variance #

posterior_variance: Tensor

Diagonal posterior variance \(p^{-1}\).

Returns:

  • variance ( Tensor ) –

    (parameters)
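
The three diagonal quantities above are related elementwise, as in this sketch with a hypothetical fitted instance la_diag of this class:

import torch

prec = la_diag.posterior_precision     # p, one entry per subnetwork parameter
var = la_diag.posterior_variance       # p^{-1}
scale = la_diag.posterior_scale        # sqrt(p^{-1})
assert torch.allclose(var, 1.0 / prec)
assert torch.allclose(scale, var.sqrt())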

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), the size of the dataset.

  • override (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar (bool, default: False ) –

    whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, esp. when subset_of_weights='all'.

Source code in laplace/baselaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar : bool, default=False
        whether to show a progress bar; updated at every batch-Hessian computation.
        Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`.
    """
    if override:
        self._init_H()
        self.loss: float | torch.Tensor = 0
        self.n_data: int = 0

    self.model.eval()

    self.mean: torch.Tensor = parameters_to_vector(self.params)
    if not self.enable_backprop:
        self.mean = self.mean.detach()

    data: (
        tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
    ) = next(iter(train_loader))

    with torch.no_grad():
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            if "backpack" in self._backend_cls.__name__.lower() or (
                isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
            ):
                raise ValueError(
                    "Currently DiagEF is not supported under CurvlinopsEF backend "
                    + "for custom models with non-tensor inputs "
                    + "(https://github.com/pytorch/functorch/issues/159). Consider "
                    + "using AsdlEF backend instead. The same limitation applies "
                    + "to all BackPACK backend"
                )

            out = self.model(data)
        else:
            X = data[0]
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                out = self.model(X.to(self._device))
    self.n_outputs = out.shape[-1]
    setattr(self.model, "output_size", self.n_outputs)

    N = len(train_loader.dataset)

    pbar = tqdm.tqdm(train_loader, disable=not progress_bar)
    pbar.set_description("[Computing Hessian]")

    for data in pbar:
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            X, y = data, data[self.dict_key_y].to(self._device)
        else:
            X, y = data
            X, y = X.to(self._device), y.to(self._device)

        if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
            raise ValueError(
                f"The model's output has {out.ndim} dims but "
                f"the target has {y.ndim} dims."
            )

        self.model.zero_grad()
        loss_batch, H_batch = self._curv_closure(X, y, N=N)
        self.loss += loss_batch
        self.H += H_batch

    self.n_data += N

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on subset of parameters (i.e. some grad are disabled), only nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
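
For orientation, the probit branch above rescales the predictive mean by \(\kappa = 1 / \sqrt{1 + \pi/8 \,\mathrm{diag}(f_{var})}\) before applying the softmax. The following is a minimal, self-contained sketch of that computation with made-up f_mu and f_var tensors; it is an illustration only, not the library's code path:

import math
import torch

# hypothetical GLM predictive moments: batch of 4 inputs, 3 classes
f_mu = torch.randn(4, 3)                    # (batch_size, n_classes)
f_var = 0.1 * torch.eye(3).expand(4, 3, 3)  # (batch_size, n_classes, n_classes)

# probit approximation: scale the mean by kappa, then apply the softmax
kappa = 1.0 / torch.sqrt(1.0 + math.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
probs = torch.softmax(kappa * f_mu, dim=-1)  # rows sum to one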

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive using the "glm" prediction type, given the GLM predictive mean f_mu and covariance f_var.

Parameters:

  • f_mu (Tensor) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
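
With diagonal_output=True, the sampling above amounts to drawing Gaussian samples around f_mu with per-output variances and, for classification, pushing each sample through the softmax. A rough stand-alone sketch with assumed shapes (it bypasses the internal normal_samples helper):

import torch

n_samples, batch_size, n_outputs = 100, 4, 3
f_mu = torch.randn(batch_size, n_outputs)
f_var_diag = 0.1 * torch.rand(batch_size, n_outputs)  # diagonal predictive variance

# reparameterization: f = f_mu + sqrt(var) * eps with eps ~ N(0, I)
eps = torch.randn(n_samples, batch_size, n_outputs)
f_samples = f_mu + f_var_diag.sqrt() * eps            # (n_samples, batch, outputs)

probs = torch.softmax(f_samples, dim=-1)              # classification samples
mean_probs = probs.mean(dim=0)                        # Monte Carlo predictive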

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –

Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob
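
The normalized branch is simply the multivariate Gaussian log density \(\log \mathcal{N}(v; \theta_{MAP}, P^{-1})\) written in terms of the posterior precision. A small sanity-check sketch against torch.distributions with a toy mean and precision (not tied to any fitted model):

import math
import torch

mean = torch.zeros(3)
P = torch.diag(torch.tensor([2.0, 1.0, 0.5]))  # toy posterior precision
value = torch.tensor([0.3, -0.2, 0.1])

diff = value - mean
square_norm = diff @ P @ diff
log_prob = -3 / 2 * math.log(2 * math.pi) + torch.logdet(P) / 2 - square_norm / 2

ref = torch.distributions.MultivariateNormal(mean, precision_matrix=P).log_prob(value)
assert torch.allclose(log_prob, ref)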

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)
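
A typical use is Thompson sampling: draw a single predictive sample per decision round and act greedily on it. The sketch below is hypothetical usage, assuming a fitted Laplace object la, a candidate tensor x_candidates of shape (n_candidates, input_dim), and a scalar (e.g. reward) output:

import torch

# one draw from the GLM posterior predictive for every candidate input
samples = la.predictive_samples(x_candidates, pred_type="glm", n_samples=1)
# samples has shape (1, n_candidates, output_shape)
scores = samples.squeeze(0)[:, 0]      # scalar score per candidate
chosen = torch.argmax(scores).item()   # index of the action to play this round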

_check_subnetwork_indices #

_check_subnetwork_indices(subnetwork_indices: LongTensor | None) -> None

Check that subnetwork indices are valid indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())).

Source code in laplace/subnetlaplace.py
def _check_subnetwork_indices(
    self, subnetwork_indices: torch.LongTensor | None
) -> None:
    """Check that subnetwork indices are valid indices of the vectorized model parameters
    (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`).
    """
    if subnetwork_indices is None:
        raise ValueError("Subnetwork indices cannot be None.")
    elif not (
        isinstance(subnetwork_indices, torch.LongTensor)
        and subnetwork_indices.numel() > 0
        and len(subnetwork_indices.shape) == 1
    ):
        raise ValueError(
            "Subnetwork indices must be non-empty 1-dimensional torch.LongTensor."
        )
    elif not (
        len(subnetwork_indices[subnetwork_indices < 0]) == 0
        and len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0
    ):
        raise ValueError(
            f"Subnetwork indices must lie between 0 and n_params={self.n_params}."
        )
    elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)):
        raise ValueError("Subnetwork indices must not contain duplicate entries.")

FullSubnetLaplace #

FullSubnetLaplace(model: Module, likelihood: Likelihood | str, subnetwork_indices: LongTensor, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, backend: Type[CurvatureInterface] | None = None, backend_kwargs: dict | None = None, asdl_fisher_kwargs: dict | None = None)

Bases: SubnetLaplace, FullLaplace

Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have \(P \in \mathbb{R}^{P \times P}\). See FullLaplace, SubnetLaplace, and BaseLaplace for the full interface.

Source code in laplace/subnetlaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    subnetwork_indices: torch.LongTensor,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    backend: Type[CurvatureInterface] | None = None,
    backend_kwargs: dict | None = None,
    asdl_fisher_kwargs: dict | None = None,
) -> None:
    if asdl_fisher_kwargs is not None:
        raise ValueError("Subnetwork Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=prior_precision,
        prior_mean=prior_mean,
        temperature=temperature,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )

    if backend is not None and not issubclass(backend, (GGNInterface, EFInterface)):
        raise ValueError("SubnetLaplace can only be used with GGN and EF.")

    # check validity of subnetwork indices and pass them to backend
    self._check_subnetwork_indices(subnetwork_indices)
    self.backend.subnetwork_indices = subnetwork_indices
    self.n_params_subnet = len(subnetwork_indices)
    self._init_H()
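
A hedged construction sketch follows; model, train_loader, and subnetwork_indices are assumed to exist (see the index check above for the required format), and the import path mirrors the source file referenced here:

from laplace.subnetlaplace import FullSubnetLaplace

la = FullSubnetLaplace(
    model,
    likelihood="classification",
    subnetwork_indices=subnetwork_indices,
    prior_precision=1.0,
)
la.fit(train_loader)  # accumulates the dense Hessian over the subnetwork only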

log_likelihood #

log_likelihood: Tensor

Compute the log likelihood on the training data after .fit() has been called. The log likelihood is computed on demand from the stored loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute the log determinant of the prior precision \(\log \det P_0\).

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Posterior precision \(P\).

Returns:

  • precision ( tensor ) –

    (parameters, parameters)

posterior_scale #

posterior_scale: Tensor

Posterior scale (square root of the covariance), i.e., \(P^{-\frac{1}{2}}\).

Returns:

  • scale ( tensor ) –

    (parameters, parameters)

posterior_covariance #

posterior_covariance: Tensor

Posterior covariance, i.e., \(P^{-1}\).

Returns:

  • covariance ( tensor ) –

    (parameters, parameters)
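
These three quantities are consistent with each other: the covariance is the inverse of the precision, and the scale is a matrix square root of the covariance. A short sketch, assuming a fitted FullSubnetLaplace instance la and the standard sample() method of ParametricLaplace:

import torch

P = la.posterior_precision      # (n_params_subnet, n_params_subnet)
cov = la.posterior_covariance   # P^(-1), same shape
scale = la.posterior_scale      # a matrix square root of the covariance

# covariance agrees with the inverse precision up to numerical error
max_err = (cov - torch.linalg.inv(P)).abs().max()

# posterior samples over the parameters; for subnetwork Laplace, entries outside
# the subnetwork stay at their MAP values
theta_samples = la.sample(n_samples=5)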

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision, if it should be changed from the current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation, if it should be changed

Returns:

  • log_marglik ( Tensor ) –

Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
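
Since the returned tensor is differentiable in the prior precision (and, for regression, in sigma_noise), it can drive simple post-hoc hyperparameter tuning. A hedged sketch for a classification Laplace object la that has already been fit:

import torch

log_prior_prec = torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior_prec], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
    neg_marglik.backward()
    hyper_optimizer.step()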

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation covers only a subset of the parameters (i.e. some gradients are disabled), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)
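
Typical predictive calls then look as follows; la is an assumed fitted Laplace object and x an input batch:

# classification: class probabilities with the (default) probit link approximation
probs = la(x, pred_type="glm", link_approx="probit")  # (batch_size, n_classes)

# regression (requires likelihood='regression'): predictive mean and variance
f_mu, f_var = la(x, pred_type="glm")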

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive using the "glm" prediction type, given the GLM predictive mean f_mu and covariance f_var.

Parameters:

  • f_mu (Tensor) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –

Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

_check_subnetwork_indices #

_check_subnetwork_indices(subnetwork_indices: LongTensor | None) -> None

Check that subnetwork indices are valid indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())).

Source code in laplace/subnetlaplace.py
def _check_subnetwork_indices(
    self, subnetwork_indices: torch.LongTensor | None
) -> None:
    """Check that subnetwork indices are valid indices of the vectorized model parameters
    (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`).
    """
    if subnetwork_indices is None:
        raise ValueError("Subnetwork indices cannot be None.")
    elif not (
        isinstance(subnetwork_indices, torch.LongTensor)
        and subnetwork_indices.numel() > 0
        and len(subnetwork_indices.shape) == 1
    ):
        raise ValueError(
            "Subnetwork indices must be non-empty 1-dimensional torch.LongTensor."
        )
    elif not (
        len(subnetwork_indices[subnetwork_indices < 0]) == 0
        and len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0
    ):
        raise ValueError(
            f"Subnetwork indices must lie between 0 and n_params={self.n_params}."
        )
    elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)):
        raise ValueError("Subnetwork indices must not contain duplicate entries.")