
laplace.lllaplace #

Classes:

  • LLLaplace

    Base class for all last-layer Laplace approximations in this library.

  • DiagLLLaplace

    Last-layer Laplace approximation with diagonal log likelihood Hessian approximation

  • KronLLLaplace

    Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation

  • FullLLLaplace

    Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation

LLLaplace #

LLLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: ParametricLaplace

Base class for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, together with a posterior precision or covariance specifying a Gaussian distribution \(\mathcal{N}(\theta_{MAP}, P^{-1})\). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision \(P\), which is given by the sum

\[ P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. \]

Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \(\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 \). In particular, we assume a scalar or diagonal prior precision so that in all cases \(P_0 = \textrm{diag}(p_0)\) and the structure of \(p_0\) can be varied.
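
For orientation, here is a minimal usage sketch. The toy model, data, and hyperparameter choices below are illustrative assumptions, not part of this class's API; the Laplace factory with subset_of_weights='last_layer' dispatches to one of the subclasses documented on this page.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# Illustrative model and data (assumptions for this sketch)
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 2))
X, y = torch.randn(512, 20), torch.randint(0, 2, (512,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=64)

# Last-layer Laplace with a Kronecker-factored Hessian approximation (-> KronLLLaplace)
la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="last_layer",
    hessian_structure="kron",
)
la.fit(train_loader)
la.optimize_prior_precision(method="marglik")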

Parameters:

  • model #

    (torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`) –
  • likelihood #

    (Likelihood or {'classification', 'regression'}) –

    determines the log likelihood Hessian approximation

  • sigma_noise #

    (Tensor or float, default: 1 ) –

    observation noise for the regression setting; must be 1 for classification

  • prior_precision #

    (Tensor or float, default: 1 ) –

    prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case

  • prior_mean #

    (Tensor or float, default: 0 ) –

    prior mean of a Gaussian prior, useful for continual learning

  • temperature #

    (float, default: 1 ) –

    temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.

  • enable_backprop #

    (bool, default: False ) –

    whether to enable backprop to the input x through the Laplace predictive. Useful for e.g. Bayesian optimization.

  • feature_reduction #

    (FeatureReduction | str | None, default: None ) –

    when the last-layer features are a tensor of dim >= 3, this tells how to reduce them into a dim-2 tensor. E.g., in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: Make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, available options are {'pick_first', 'pick_last', 'average'} (see the sketch after this parameter list).

  • dict_key_x #

    (str, default: 'input_ids' ) –

    The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • dict_key_y #

    (str, default: 'labels' ) –

    The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • backend #

    (subclasses of `laplace.curvature.CurvatureInterface`, default: None ) –

    backend for access to curvature/Hessian approximations

  • last_layer_name #

    (str | None, default: None ) –

    name of the model's last layer; if None, it will be determined automatically

  • backend_kwargs #

    (dict, default: None ) –

    arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
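
The feature_reduction options roughly correspond to the following operations on a dim-3 feature tensor (a sketch for illustration only; the reduction is applied internally by the FeatureExtractor, and the shapes here are assumptions):

import torch

feats = torch.randn(8, 16, 32)   # assumed (batch_size, seq_len, embd_dim) features

pick_first = feats[:, 0, :]      # 'pick_first' -> (batch_size, embd_dim)
pick_last = feats[:, -1, :]      # 'pick_last'  -> (batch_size, embd_dim)
average = feats.mean(dim=1)      # 'average'    -> (batch_size, embd_dim)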

Methods:

  • log_marginal_likelihood

    Compute the Laplace approximation to the log marginal likelihood subject

  • __call__

    Compute the posterior predictive on input data x.

  • square_norm

    Compute the square norm under the posterior precision with value - self.mean as Δ:

  • log_prob

    Compute the log probability under the (current) Laplace approximation.

  • predictive_samples

    Sample from the posterior predictive on input data x.

  • functional_variance

    Compute functional variance for the 'glm' predictive:

  • functional_covariance

    Compute functional covariance for the 'glm' predictive:

  • sample

    Sample from the Laplace posterior approximation, i.e.,

  • fit

    Fit the local Laplace approximation at the parameters of the model.

  • functional_variance_fast

    Should be overridden if there exists a trick to make this fast!

Attributes:

  • log_likelihood (Tensor) –

    Compute log likelihood on the training data after .fit() has been called.

  • scatter (Tensor) –

    Computes the scatter, a term of the log marginal likelihood that

  • log_det_prior_precision (Tensor) –

    Compute log determinant of the prior precision

  • log_det_posterior_precision (Tensor) –

    Compute log determinant of the posterior precision

  • log_det_ratio (Tensor) –

    Compute the log determinant ratio, a part of the log marginal likelihood.

  • posterior_precision (Tensor) –

    Compute or return the posterior precision \(P\).

  • prior_precision_diag (Tensor) –

    Obtain the diagonal prior precision \(p_0\) constructed from either

Source code in laplace/lllaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    feature_reduction: FeatureReduction | str | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    last_layer_name: str | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    if asdl_fisher_kwargs is not None:
        raise ValueError("Last-layer Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=1.0,
        prior_mean=0.0,
        temperature=temperature,
        enable_backprop=enable_backprop,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )
    self.model = FeatureExtractor(
        deepcopy(model),
        last_layer_name=last_layer_name,
        enable_backprop=enable_backprop,
        feature_reduction=feature_reduction,
    )

    if self.model.last_layer is None:
        self.mean: torch.Tensor | None = None
        self.n_params: int | None = None
        self.n_layers: int | None = None
        # ignore checks of prior mean setter temporarily, check on .fit()
        self._prior_precision: float | torch.Tensor = prior_precision
        self._prior_mean: float | torch.Tensor = prior_mean
    else:
        self.n_params: int = len(
            parameters_to_vector(self.model.last_layer.parameters())
        )
        self.n_layers: int | None = len(list(self.model.last_layer.parameters()))
        self.prior_precision: float | torch.Tensor = prior_precision
        self.prior_mean: float | torch.Tensor = prior_mean
        self.mean: float | torch.Tensor = self.prior_mean
        self._init_H()

    self._backend_kwargs["last_layer"] = True
    self._last_layer_name: str | None = last_layer_name

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_posterior_precision #

log_det_posterior_precision: Tensor

Compute log determinant of the posterior precision \(\log \det P\) which depends on the subclasses structure used for the Hessian approximation.

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Compute or return the posterior precision \(P\).

Returns:

  • posterior_prec ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision #

    (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise #

    (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
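
Because the returned tensor is differentiable in prior_precision and sigma_noise, these hyperparameters can be tuned by gradient descent on the negative log marginal likelihood. A minimal sketch, assuming a regression Laplace object la that has already been fit:

import torch

log_prior = torch.ones(1, requires_grad=True)
log_sigma = torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
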
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on a subset of parameters (i.e., some gradients are disabled), only the nn predictive is supported.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting #

    (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
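
For illustration, a short sketch of typical calls, assuming a fitted Laplace object la and an input batch x (both assumed to exist):

# Classification: probit-approximated GLM predictive -> (batch_size, n_classes) probabilities
probs = la(x, pred_type="glm", link_approx="probit")

# Regression: GLM predictive -> mean (batch_size, n_outputs) and predictive (co)variance
f_mu, f_var = la(x, pred_type="glm")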

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu #

    (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

square_norm #

square_norm(value) -> Tensor

Compute the square norm under the posterior precision with value - self.mean as \(\Delta\):

\[ \Delta^\top P \Delta \]

Returns:

  • square_form
Source code in laplace/baselaplace.py
def square_norm(self, value) -> torch.Tensor:
    """Compute the square norm under post. Precision with `value-self.mean` as 𝛥:

    $$
        \\Delta^\top P \\Delta
    $$

    Returns
    -------
    square_form
    """
    raise NotImplementedError

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value #

    (Tensor) –
  • normalized #

    (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples #

    (int, default: 100 ) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)
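
As a hedged illustration of Thompson sampling with this method, assuming a fitted single-output regression Laplace object la and a tensor of candidate inputs x_candidates (both assumptions of this sketch):

# Draw one function sample per candidate and act greedily under that sample.
samples = la.predictive_samples(x_candidates, pred_type="glm", n_samples=1)  # (1, batch, 1)
best_idx = samples[0].squeeze(-1).argmax()  # index of the candidate to evaluate next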

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

functional_variance #

functional_variance(Js: Tensor) -> Tensor

Compute functional variance for the 'glm' predictive: f_var[i] = Js[i] @ P.inv() @ Js[i].T, which is an output x output predictive covariance matrix. Mathematically, we have for a single Jacobian \(\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}\) the output covariance matrix \( \mathcal{J} P^{-1} \mathcal{J}^T \).

Parameters:

  • Js #

    (Tensor) –

    Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns:

  • f_var ( Tensor ) –

    output covariance (batch, outputs, outputs)
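
The base class leaves this method abstract. As a rough dense reference (not the subclasses' efficient implementations), assuming the posterior covariance is available as a (parameters, parameters) matrix Sigma, the computation amounts to:

import torch

def functional_variance_dense(Js: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    # f_var[i] = Js[i] @ Sigma @ Js[i].T for every batch element
    # Js: (batch, outputs, parameters), Sigma: (parameters, parameters)
    return torch.einsum("bop,pq,bkq->bok", Js, Sigma, Js)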

Source code in laplace/baselaplace.py
def functional_variance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional variance for the `'glm'` predictive:
    `f_var[i] = Js[i] @ P.inv() @ Js[i].T`, which is a output x output
    predictive covariance matrix.
    Mathematically, we have for a single Jacobian
    \\(\\mathcal{J} = \\nabla_\\theta f(x;\\theta)\\vert_{\\theta_{MAP}}\\)
    the output covariance matrix
    \\( \\mathcal{J} P^{-1} \\mathcal{J}^T \\).

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch, outputs, parameters)`

    Returns
    -------
    f_var : torch.Tensor
        output covariance `(batch, outputs, outputs)`
    """
    raise NotImplementedError

functional_covariance #

functional_covariance(Js: Tensor) -> Tensor

Compute functional covariance for the 'glm' predictive: f_cov = Js @ P.inv() @ Js.T, which is a (batch * outputs) x (batch * outputs) predictive covariance matrix.

This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization.

Parameters:

  • Js #

    (Tensor) –

    Jacobians of model output wrt parameters (batch*outputs, parameters)

Returns:

  • f_cov ( Tensor ) –

    output covariance (batch*outputs, batch*outputs)
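
Analogously, a rough dense reference for the joint case, assuming Js flattened to (batch*outputs, parameters) and a dense posterior covariance Sigma (both assumptions of this sketch):

import torch

def functional_covariance_dense(Js: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    # Js: (batch*outputs, parameters) -> f_cov: (batch*outputs, batch*outputs)
    return Js @ Sigma @ Js.T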

Source code in laplace/baselaplace.py
def functional_covariance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional covariance for the `'glm'` predictive:
    `f_cov = Js @ P.inv() @ Js.T`, which is a batch*output x batch*output
    predictive covariance matrix.

    This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
    Useful for joint predictions, such as in batched Bayesian optimization.

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch*outputs, parameters)`

    Returns
    -------
    f_cov : torch.Tensor
        output covariance `(batch*outputs, batch*outputs)`
    """
    raise NotImplementedError

sample #

sample(n_samples: int = 100, generator: Generator | None = None) -> Tensor

Sample from the Laplace posterior approximation, i.e., \( \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})\).

Parameters:

  • n_samples #

    (int, default: 100 ) –

    number of samples

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples

Returns:

  • samples ( Tensor ) –
Source code in laplace/baselaplace.py
def sample(
    self, n_samples: int = 100, generator: torch.Generator | None = None
) -> torch.Tensor:
    """Sample from the Laplace posterior approximation, i.e.,
    \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).

    Parameters
    ----------
    n_samples : int, default=100
        number of samples

    generator : torch.Generator, optional
        random number generator to control the samples

    Returns
    -------
    samples: torch.Tensor
    """
    raise NotImplementedError

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader #

    (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override #

    (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation. Note that last-layer Laplace approximations do not support override=False and raise a ValueError in that case (see the source below).

  • progress_bar #

    (bool, default: False ) –
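
A minimal fitting sketch with an assumed in-memory regression dataset (any DataLoader whose .dataset exposes the dataset size works; X, y, and the fitted object la are assumptions of this sketch):

import torch
from torch.utils.data import DataLoader, TensorDataset

X, y = torch.randn(512, 20), torch.randn(512, 1)  # assumed toy data
train_loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

la.fit(train_loader)  # computes the last-layer MAP mean and the Hessian approximation
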
Source code in laplace/lllaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar: bool, default=False
    """
    if not override:
        raise ValueError(
            "Last-layer Laplace approximations do not support `override=False`."
        )

    self.model.eval()

    if self.model.last_layer is None:
        self.data: tuple[torch.Tensor, torch.Tensor] | MutableMapping = next(
            iter(train_loader)
        )
        self._find_last_layer(self.data)
        params: torch.Tensor = parameters_to_vector(
            self.model.last_layer.parameters()
        ).detach()
        self.n_params: int = len(params)
        self.n_layers: int = len(list(self.model.last_layer.parameters()))
        # here, check the already set prior precision again
        self.prior_precision: float | torch.Tensor = self._prior_precision
        self.prior_mean: float | torch.Tensor = self._prior_mean
        self._init_H()

    super().fit(train_loader, override=override)
    self.mean: torch.Tensor = parameters_to_vector(
        self.model.last_layer.parameters()
    )

    if not self.enable_backprop:
        self.mean = self.mean.detach()

functional_variance_fast #

functional_variance_fast(X)

Should be overridden if there exists a trick to make this fast!

Parameters:

  • X #

    (Tensor of shape (batch_size, input_dim)) –

Returns:

  • f_mu ( torch.Tensor of shape (batch_size, num_outputs) ) –

    predictive mean

  • f_var_diag ( torch.Tensor of shape (batch_size, num_outputs) ) –

    Corresponding to the diagonal of the covariance matrix of the outputs

Source code in laplace/lllaplace.py
def functional_variance_fast(self, X):
    """
    Should be overriden if there exists a trick to make this fast!

    Parameters
    ----------
    X: torch.Tensor of shape (batch_size, input_dim)

    Returns
    -------
    f_var_diag: torch.Tensor of shape (batch_size, num_outputs)
        Corresponding to the diagonal of the covariance matrix of the outputs
    """
    Js, f_mu = self.backend.last_layer_jacobians(X, self.enable_backprop)
    f_cov = self.functional_variance(Js)  # No trick possible for Full Laplace
    f_var = torch.diagonal(f_cov, dim1=-2, dim2=-1)
    return f_mu, f_var

DiagLLLaplace #

DiagLLLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: LLLaplace, DiagLaplace

Last-layer Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have \(P \approx \textrm{diag}(P)\). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
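
For example, a sketch of obtaining this variant via the Laplace factory (model and train_loader are assumed to exist):

from laplace import Laplace

la = Laplace(
    model,
    likelihood="regression",
    subset_of_weights="last_layer",
    hessian_structure="diag",
)
la.fit(train_loader)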

Methods:

  • fit

    Fit the local Laplace approximation at the parameters of the model.

  • log_marginal_likelihood

    Compute the Laplace approximation to the log marginal likelihood subject

  • __call__

    Compute the posterior predictive on input data x.

  • log_prob

    Compute the log probability under the (current) Laplace approximation.

  • predictive_samples

    Sample from the posterior predictive on input data x.

Attributes:

  • log_likelihood (Tensor) –

    Compute log likelihood on the training data after .fit() has been called.

  • prior_precision_diag (Tensor) –

    Obtain the diagonal prior precision \(p_0\) constructed from either

  • scatter (Tensor) –

    Computes the scatter, a term of the log marginal likelihood that

  • log_det_prior_precision (Tensor) –

    Compute log determinant of the prior precision

  • log_det_ratio (Tensor) –

    Compute the log determinant ratio, a part of the log marginal likelihood.

  • posterior_precision (Tensor) –

    Diagonal posterior precision \(p\).

  • posterior_scale (Tensor) –

    Diagonal posterior scale \(\sqrt{p^{-1}}\).

  • posterior_variance (Tensor) –

    Diagonal posterior variance \(p^{-1}\).

Source code in laplace/lllaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    feature_reduction: FeatureReduction | str | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    last_layer_name: str | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    if asdl_fisher_kwargs is not None:
        raise ValueError("Last-layer Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=1.0,
        prior_mean=0.0,
        temperature=temperature,
        enable_backprop=enable_backprop,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )
    self.model = FeatureExtractor(
        deepcopy(model),
        last_layer_name=last_layer_name,
        enable_backprop=enable_backprop,
        feature_reduction=feature_reduction,
    )

    if self.model.last_layer is None:
        self.mean: torch.Tensor | None = None
        self.n_params: int | None = None
        self.n_layers: int | None = None
        # ignore checks of prior mean setter temporarily, check on .fit()
        self._prior_precision: float | torch.Tensor = prior_precision
        self._prior_mean: float | torch.Tensor = prior_mean
    else:
        self.n_params: int = len(
            parameters_to_vector(self.model.last_layer.parameters())
        )
        self.n_layers: int | None = len(list(self.model.last_layer.parameters()))
        self.prior_precision: float | torch.Tensor = prior_precision
        self.prior_mean: float | torch.Tensor = prior_mean
        self.mean: float | torch.Tensor = self.prior_mean
        self._init_H()

    self._backend_kwargs["last_layer"] = True
    self._last_layer_name: str | None = last_layer_name

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Diagonal posterior precision \(p\).

Returns:

  • precision ( tensor ) –

    (parameters)

posterior_scale #

posterior_scale: Tensor

Diagonal posterior scale \(\sqrt{p^{-1}}\).

Returns:

  • scale ( tensor ) –

    (parameters)

posterior_variance #

posterior_variance: Tensor

Diagonal posterior variance \(p^{-1}\).

Returns:

  • variance ( tensor ) –

    (parameters)
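
These three attributes are related elementwise, so sampling from the diagonal posterior reduces to a simple reparameterization. A sketch, assuming a fitted DiagLLLaplace object la:

import torch

n_samples = 16
eps = torch.randn(n_samples, la.n_params, device=la.mean.device)
theta_samples = la.mean + la.posterior_scale * eps  # (n_samples, n_params)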

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader #

    (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override #

    (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation. Note that last-layer Laplace approximations do not support override=False and raise a ValueError in that case (see the source below).

  • progress_bar #

    (bool, default: False ) –
Source code in laplace/lllaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar: bool, default=False
    """
    if not override:
        raise ValueError(
            "Last-layer Laplace approximations do not support `override=False`."
        )

    self.model.eval()

    if self.model.last_layer is None:
        self.data: tuple[torch.Tensor, torch.Tensor] | MutableMapping = next(
            iter(train_loader)
        )
        self._find_last_layer(self.data)
        params: torch.Tensor = parameters_to_vector(
            self.model.last_layer.parameters()
        ).detach()
        self.n_params: int = len(params)
        self.n_layers: int = len(list(self.model.last_layer.parameters()))
        # here, check the already set prior precision again
        self.prior_precision: float | torch.Tensor = self._prior_precision
        self.prior_mean: float | torch.Tensor = self._prior_mean
        self._init_H()

    super().fit(train_loader, override=override)
    self.mean: torch.Tensor = parameters_to_vector(
        self.model.last_layer.parameters()
    )

    if not self.enable_backprop:
        self.mean = self.mean.detach()

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision #

    (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise #

    (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on a subset of parameters (i.e., some gradients are disabled), only the nn predictive is supported.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting #

    (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
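
The probit branch above admits a compact standalone illustration. The sketch below reproduces the same computation with plain torch tensors and toy numbers (it calls no library code): each logit's mean is scaled by \(\kappa = 1/\sqrt{1 + \pi \sigma^2 / 8}\) before the softmax, so outputs with larger predictive variance are shrunk towards a uniform class distribution.

import math
import torch

# Toy GLM predictive for 2 inputs and 3 classes (illustrative numbers only).
f_mu = torch.tensor([[2.0, 0.5, -1.0],
                     [0.2, 0.1, 0.0]])
f_var = torch.stack([torch.eye(3) * 0.5, torch.eye(3) * 4.0])  # (batch, K, K)

# Probit approximation to the softmax link, mirroring the branch above.
kappa = 1.0 / torch.sqrt(1.0 + math.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
probs = torch.softmax(kappa * f_mu, dim=-1)  # larger variance -> logits shrunk towards 0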

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu #

    (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
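
For intuition, the following sketch re-implements the diagonal-output sampling path with plain torch (the library's normal_samples helper is not called, so treat this only as an illustration of the reparameterization it is assumed to perform): draw standard-normal noise, scale by the predictive standard deviation, shift by the mean, and, for classification, push each sample through a softmax before averaging.

import torch

def glm_samples_diag(f_mu, f_var_diag, n_samples, generator=None):
    # Reparameterized sampling from N(f_mu, diag(f_var_diag)); returns
    # a tensor of shape (n_samples, batch_size, output_shape).
    eps = torch.randn(n_samples, *f_mu.shape, generator=generator)
    return f_mu + f_var_diag.sqrt() * eps

f_mu = torch.tensor([[1.0, -0.5, 0.0]])
f_var_diag = torch.tensor([[0.3, 0.3, 0.3]])
gen = torch.Generator().manual_seed(0)
samples = glm_samples_diag(f_mu, f_var_diag, n_samples=1000, generator=gen)
probs = torch.softmax(samples, dim=-1).mean(dim=0)  # MC estimate of class probabilities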

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value #

    (Tensor) –
  • normalized #

    (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob
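
In the notation of this page, the normalized branch above evaluates the Gaussian log-density of the Laplace posterior \(\mathcal{N}(\theta_{MAP}, P^{-1})\) at value, with \(d\) denoting n_params and square_norm computing the quadratic form:

\[ \log p(v) = -\tfrac{d}{2}\log(2\pi) + \tfrac{1}{2}\log\det P - \tfrac{1}{2}(v - \theta_{MAP})^{\top} P (v - \theta_{MAP}). \]

With normalized=False only the last, value-dependent term is returned.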

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples #

    (int, default: 100 ) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)
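
As a usage note, predictive_samples lends itself to Thompson sampling: draw a single function sample over all candidate actions and act greedily on it. A minimal sketch, assuming a toy scalar-reward regressor and randomly generated candidates (both are assumptions of this sketch, not library code):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.lllaplace import KronLLLaplace

# Toy reward model fitted with a last-layer Laplace posterior.
reg_model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 1))
Xr, yr = torch.randn(256, 20), torch.randn(256, 1)
la_reg = KronLLLaplace(reg_model, likelihood="regression")
la_reg.fit(DataLoader(TensorDataset(Xr, yr), batch_size=32))

# Thompson sampling over 50 hypothetical candidate actions.
candidates = torch.randn(50, 20)
f_sample = la_reg.predictive_samples(candidates, pred_type="glm", n_samples=1)  # (1, 50, 1)
best_action = candidates[f_sample.squeeze(0).squeeze(-1).argmax()]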

KronLLLaplace #

KronLLLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: LLLaplace, KronLaplace

Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that \(P \approx Q \otimes H\). See KronLaplace, LLLaplace, and BaseLaplace for the full interface, and see laplace.utils.matrix.Kron and laplace.utils.matrix.KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing them up, and KronDecomposed is used to add the prior, apply a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Use of damping is possible by initializing or setting damping=True.
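
To see why the Kronecker factorization pays off, consider (purely as an illustration, ignoring the bias) a last linear layer mapping \(D\) features to \(K\) outputs, i.e. \(n = KD\) last-layer parameters. Then

\[ P \approx Q \otimes H, \]

with one factor of size \(D \times D\) over the input features and one of size \(K \times K\) over the outputs (which of the two is called \(Q\) and which \(H\) follows the backend's convention), so storing the factors costs \(D^2 + K^2\) entries instead of the \((KD)^2\) entries of a dense posterior precision.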

Methods:

  • fit

    Fit the local Laplace approximation at the parameters of the model.

  • log_marginal_likelihood

    Compute the Laplace approximation to the log marginal likelihood subject

  • __call__

    Compute the posterior predictive on input data x.

  • log_prob

    Compute the log probability under the (current) Laplace approximation.

  • predictive_samples

    Sample from the posterior predictive on input data x.

Attributes:

  • log_likelihood (Tensor) –

    Compute log likelihood on the training data after .fit() has been called.

  • prior_precision_diag (Tensor) –

    Obtain the diagonal prior precision \(p_0\) constructed from either

  • scatter (Tensor) –

    Computes the scatter, a term of the log marginal likelihood that

  • log_det_prior_precision (Tensor) –

    Compute log determinant of the prior precision

  • log_det_ratio (Tensor) –

    Compute the log determinant ratio, a part of the log marginal likelihood.

  • posterior_precision (KronDecomposed) –

    Kronecker factored Posterior precision \(P\).

Source code in laplace/lllaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    feature_reduction: FeatureReduction | str | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    last_layer_name: str | None = None,
    damping: bool = False,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    self.damping = damping
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        feature_reduction,
        dict_key_x,
        dict_key_y,
        backend,
        last_layer_name,
        backend_kwargs,
        asdl_fisher_kwargs,
    )

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: KronDecomposed

Kronecker factored Posterior precision \(P\).

Returns:

  • precision ( `laplace.utils.matrix.KronDecomposed` ) –

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader #

    (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override #

    (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar #

    (bool, default: False ) –
Source code in laplace/lllaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar: bool, default=False
    """
    if not override:
        raise ValueError(
            "Last-layer Laplace approximations do not support `override=False`."
        )

    self.model.eval()

    if self.model.last_layer is None:
        self.data: tuple[torch.Tensor, torch.Tensor] | MutableMapping = next(
            iter(train_loader)
        )
        self._find_last_layer(self.data)
        params: torch.Tensor = parameters_to_vector(
            self.model.last_layer.parameters()
        ).detach()
        self.n_params: int = len(params)
        self.n_layers: int = len(list(self.model.last_layer.parameters()))
        # here, check the already set prior precision again
        self.prior_precision: float | torch.Tensor = self._prior_precision
        self.prior_mean: float | torch.Tensor = self._prior_mean
        self._init_H()

    super().fit(train_loader, override=override)
    self.mean: torch.Tensor = parameters_to_vector(
        self.model.last_layer.parameters()
    )

    if not self.enable_backprop:
        self.mean = self.mean.detach()
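
Besides (X, y) tensor batches, fit accepts dict-like batches keyed by dict_key_x and dict_key_y, as used for transformer-style models. A hedged sketch of the latter; the toy dataset, the wrapper module, and the assumption that the wrapped model is called with the collated dict itself are illustrative and not taken from the library:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from laplace.lllaplace import KronLLLaplace

class DictDataset(Dataset):
    # Toy dataset yielding dict samples; the default collate_fn stacks each key.
    def __init__(self, n=256):
        self.x, self.y = torch.randn(n, 20), torch.randint(0, 3, (n,))
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return {"input_ids": self.x[i], "labels": self.y[i]}

class DictModel(nn.Module):
    # Toy model that reads the input tensor out of the dict batch.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(20, 50), nn.ReLU())
        self.head = nn.Linear(50, 3)  # the last layer treated probabilistically
    def forward(self, data):
        return self.head(self.backbone(data["input_ids"]))

la = KronLLLaplace(DictModel(), likelihood="classification",
                   dict_key_x="input_ids", dict_key_y="labels")
la.fit(DataLoader(DictDataset(), batch_size=32))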

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision #

    (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise #

    (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
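
Because the returned tensor is differentiable in prior_precision (and sigma_noise for regression), the marginal likelihood can be maximized with standard gradient-based optimization. A minimal sketch, assuming la is a fitted last-layer Laplace for classification as in the earlier example:

import torch

log_prior_prec = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([log_prior_prec], lr=1e-1)
for _ in range(100):
    opt.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
    neg_marglik.backward()
    opt.step()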

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on subset of parameters (i.e. some grad are disabled), only nn predictive is supported.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting #

    (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu #

    (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value #

    (Tensor) –
  • normalized #

    (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples #

    (int, default: 100 ) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

FullLLLaplace #

FullLLLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: LLLaplace, FullLaplace

Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have \(P \in \mathbb{R}^{P \times P}\). See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
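
A minimal usage sketch, assuming the toy model and train_loader from the earlier classification example: the dense variant exposes the full posterior precision, its matrix square root, and the covariance (all of shape (n_params, n_params)), which makes it easy to draw explicit last-layer weight samples.

import torch
from laplace.lllaplace import FullLLLaplace

la_full = FullLLLaplace(model, likelihood="classification")
la_full.fit(train_loader)

P = la_full.posterior_precision    # dense precision
S = la_full.posterior_scale        # P^(-1/2)
C = la_full.posterior_covariance   # P^(-1)

# One last-layer weight sample from N(theta_MAP, P^(-1)) via the scale matrix.
theta_sample = la_full.mean + S @ torch.randn(la_full.n_params)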

Methods:

  • fit

    Fit the local Laplace approximation at the parameters of the model.

  • log_marginal_likelihood

    Compute the Laplace approximation to the log marginal likelihood subject

  • __call__

    Compute the posterior predictive on input data x.

  • log_prob

    Compute the log probability under the (current) Laplace approximation.

  • predictive_samples

    Sample from the posterior predictive on input data x.

  • functional_variance_fast

    Should be overridden if there exists a trick to make this fast!

Attributes:

  • log_likelihood (Tensor) –

    Compute log likelihood on the training data after .fit() has been called.

  • prior_precision_diag (Tensor) –

    Obtain the diagonal prior precision \(p_0\) constructed from either

  • scatter (Tensor) –

    Computes the scatter, a term of the log marginal likelihood that

  • log_det_prior_precision (Tensor) –

    Compute log determinant of the prior precision

  • log_det_ratio (Tensor) –

    Compute the log determinant ratio, a part of the log marginal likelihood.

  • posterior_precision (Tensor) –

    Posterior precision \(P\).

  • posterior_scale (Tensor) –

    Posterior scale (square root of the covariance), i.e.,

  • posterior_covariance (Tensor) –

    Posterior covariance, i.e., \(P^{-1}\).

Source code in laplace/lllaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    feature_reduction: FeatureReduction | str | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    last_layer_name: str | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    if asdl_fisher_kwargs is not None:
        raise ValueError("Last-layer Laplace does not support asdl_fisher_kwargs.")

    self.H = None
    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=1.0,
        prior_mean=0.0,
        temperature=temperature,
        enable_backprop=enable_backprop,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )
    self.model = FeatureExtractor(
        deepcopy(model),
        last_layer_name=last_layer_name,
        enable_backprop=enable_backprop,
        feature_reduction=feature_reduction,
    )

    if self.model.last_layer is None:
        self.mean: torch.Tensor | None = None
        self.n_params: int | None = None
        self.n_layers: int | None = None
        # ignore checks of prior mean setter temporarily, check on .fit()
        self._prior_precision: float | torch.Tensor = prior_precision
        self._prior_mean: float | torch.Tensor = prior_mean
    else:
        self.n_params: int = len(
            parameters_to_vector(self.model.last_layer.parameters())
        )
        self.n_layers: int | None = len(list(self.model.last_layer.parameters()))
        self.prior_precision: float | torch.Tensor = prior_precision
        self.prior_mean: float | torch.Tensor = prior_mean
        self.mean: float | torch.Tensor = self.prior_mean
        self._init_H()

    self._backend_kwargs["last_layer"] = True
    self._last_layer_name: str | None = last_layer_name

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Posterior precision \(P\).

Returns:

  • precision ( tensor ) –

    (parameters, parameters)

posterior_scale #

posterior_scale: Tensor

Posterior scale (square root of the covariance), i.e., \(P^{-\frac{1}{2}}\).

Returns:

  • scale ( tensor ) –

    (parameters, parameters)

posterior_covariance #

posterior_covariance: Tensor

Posterior covariance, i.e., \(P^{-1}\).

Returns:

  • covariance ( tensor ) –

    (parameters, parameters)

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader #

    (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override #

    (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar #

    (bool, default: False ) –
Source code in laplace/lllaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar: bool, default=False
    """
    if not override:
        raise ValueError(
            "Last-layer Laplace approximations do not support `override=False`."
        )

    self.model.eval()

    if self.model.last_layer is None:
        self.data: tuple[torch.Tensor, torch.Tensor] | MutableMapping = next(
            iter(train_loader)
        )
        self._find_last_layer(self.data)
        params: torch.Tensor = parameters_to_vector(
            self.model.last_layer.parameters()
        ).detach()
        self.n_params: int = len(params)
        self.n_layers: int = len(list(self.model.last_layer.parameters()))
        # here, check the already set prior precision again
        self.prior_precision: float | torch.Tensor = self._prior_precision
        self.prior_mean: float | torch.Tensor = self._prior_mean
        self._init_H()

    super().fit(train_loader, override=override)
    self.mean: torch.Tensor = parameters_to_vector(
        self.model.last_layer.parameters()
    )

    if not self.enable_backprop:
        self.mean = self.mean.detach()

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision #

    (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise #

    (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on subset of parameters (i.e. some grad are disabled), only nn predictive is supported.

  • link_approx #

    (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting #

    (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)
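
For illustration, a minimal usage sketch of this predictive call. The names la (a fitted Laplace object) and x (a batch of inputs on the model's device) are placeholders, not part of the API:

import torch

# Hypothetical: `la` is a fitted last-layer Laplace object on a classification task.
# GLM predictive with the probit link approximation: class probabilities.
probs = la(x, pred_type="glm", link_approx="probit")  # (batch_size, n_classes)

# Monte Carlo link approximation; the generator must live on the model's device.
gen = torch.Generator(device=x.device).manual_seed(0)
probs_mc = la(x, pred_type="glm", link_approx="mc", n_samples=100, generator=gen)

# With a regression likelihood the same call returns a (mean, variance) tuple:
# f_mu, f_var = la(x, pred_type="glm")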

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx #

    (LinkApprox or str in {'mc', 'probit', 'bridge', 'bridge_norm'}, default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
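
As a standalone illustration of the probit branch above, the following sketch applies the same moderation to made-up GLM outputs (all tensors here are dummies for the example):

import math
import torch

# Dummy GLM predictive: means (batch, K) and full output covariances (batch, K, K).
batch, K = 4, 3
f_mu = torch.randn(batch, K)
f_var = 0.5 * torch.eye(K).expand(batch, K, K)

# Probit approximation: scale each logit by kappa = 1 / sqrt(1 + pi/8 * var)
# before the softmax, so that more uncertain outputs yield softer probabilities.
kappa = 1.0 / torch.sqrt(1.0 + math.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
probs = torch.softmax(kappa * f_mu, dim=-1)  # (batch, K), rows sum to 1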

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu #

    (Tensor) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
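
The normal_samples helper called above comes from the library's utilities; purely for intuition, a rough, simplified equivalent of its behavior is sketched here (this is an illustration, not the library's implementation):

import torch

def sample_gaussian(f_mu, f_var, n_samples, generator=None):
    """Sketch: draw samples from N(f_mu, f_var), batched over the first dimension.

    f_mu:  (batch, K) means; f_var: (batch, K) variances or (batch, K, K) covariances.
    Returns samples of shape (n_samples, batch, K).
    """
    eps = torch.randn(n_samples, *f_mu.shape, generator=generator)
    if f_var.dim() == f_mu.dim():  # diagonal case: scale by the standard deviation
        return f_mu + f_var.sqrt() * eps
    # full-covariance case: use a Cholesky factor (f_var must be positive definite)
    scale = torch.linalg.cholesky(f_var)
    return f_mu + torch.einsum("bij,nbj->nbi", scale, eps)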

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value #

    (Tensor) –

    parameter vector at which to evaluate the log density of the Laplace posterior approximation.
  • normalized #

    (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –

    log probability of value under the Gaussian posterior approximation; if normalized=False, only the value-dependent quadratic term is returned.

Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob
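
For reference, with d = n_params and P the posterior precision, the normalized branch evaluates the multivariate Gaussian log density (assuming square_norm(value) computes the quadratic form in the exponent):

\[ \log p(v) = -\frac{d}{2} \log(2\pi) + \frac{1}{2} \log \det P - \frac{1}{2} (v - \theta_{MAP})^\top P \, (v - \theta_{MAP}), \]

while normalized=False returns only the last term.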

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type #

    (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples #

    (int, default: 100 ) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)
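
A usage sketch for Thompson-sampling-style selection, again with a placeholder fitted Laplace object la over a single-output reward or regression model and candidate inputs x_cand:

import torch

# Assumes the model lives on CPU; otherwise create the generator on the model's device.
gen = torch.Generator().manual_seed(0)

# One draw from the posterior predictive per candidate suffices for Thompson sampling.
samples = la.predictive_samples(x_cand, pred_type="glm", n_samples=1, generator=gen)
f_sample = samples[0]  # (batch_size, output_shape)

# Pick the candidate whose sampled output (e.g. predicted reward) is largest.
best_idx = f_sample.squeeze(-1).argmax()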

functional_variance_fast #

functional_variance_fast(X)

Should be overridden if there exists a trick to make this fast!

Parameters:

  • X #

    (torch.Tensor of shape (batch_size, input_dim)) –

Returns:

  • f_mu ( torch.Tensor of shape (batch_size, num_outputs) ) –

    Predictive mean of the outputs

  • f_var_diag ( torch.Tensor of shape (batch_size, num_outputs) ) –

    Corresponding to the diagonal of the covariance matrix of the outputs

Source code in laplace/lllaplace.py
def functional_variance_fast(self, X):
    """
    Should be overridden if there exists a trick to make this fast!

    Parameters
    ----------
    X: torch.Tensor of shape (batch_size, input_dim)

    Returns
    -------
    f_mu: torch.Tensor of shape (batch_size, num_outputs)
        Predictive mean of the outputs
    f_var_diag: torch.Tensor of shape (batch_size, num_outputs)
        Corresponding to the diagonal of the covariance matrix of the outputs
    """
    Js, f_mu = self.backend.last_layer_jacobians(X, self.enable_backprop)
    f_cov = self.functional_variance(Js)  # No trick possible for Full Laplace
    f_var = torch.diagonal(f_cov, dim1=-2, dim2=-1)
    return f_mu, f_var
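
The "trick" alluded to in the docstring is to obtain only the diagonal of the output covariance \(J \Sigma J^\top\) without materializing the full matrix; subclasses can additionally exploit the diagonal or Kronecker structure of \(\Sigma\). A generic, made-up illustration of the diagonal contraction:

import torch

# Made-up shapes: batch of 8, 3 outputs, 20 last-layer parameters.
Js = torch.randn(8, 3, 20)     # per-example Jacobians (batch, n_outputs, n_params)
Sigma = 0.1 * torch.eye(20)    # posterior covariance over the last-layer parameters

# Slow path: build the full (batch, n_outputs, n_outputs) covariance, then take its diagonal.
f_cov = torch.einsum("nkp,pq,nlq->nkl", Js, Sigma, Js)
f_var_slow = torch.diagonal(f_cov, dim1=-2, dim2=-1)

# Fast path: contract straight to the diagonal, never forming the off-diagonal entries.
f_var_fast = torch.einsum("nkp,pq,nkq->nk", Js, Sigma, Js)

assert torch.allclose(f_var_slow, f_var_fast, atol=1e-5)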