
laplace.baselaplace #

ParametricLaplace #

ParametricLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: BaseLaplace

Parametric Laplace class.

Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.

A Laplace approximation is represented by a MAP estimate, given by the model parameters, together with a posterior precision or covariance that specifies a Gaussian distribution \(\mathcal{N}(\theta_{MAP}, P^{-1})\). The goal of this class is to compute the posterior precision \(P\), which is given by

\[ P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. \]

Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \(\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 \). In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases \(P_0 = \textrm{diag}(p_0)\) and the structure of \(p_0\) can be varied.
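
As a rough sketch of the intended workflow (the toy model and data are illustrative only; DiagLaplace below is one concrete subclass providing the diagonal approximation):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.baselaplace import DiagLaplace

# Toy regression model and data, for illustration only.
model = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
X, y = torch.randn(128, 2), torch.randn(128, 1)
train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

# Construct a concrete ParametricLaplace subclass and accumulate curvature over the data.
la = DiagLaplace(model, likelihood="regression", prior_precision=1.0)
la.fit(train_loader)

# The (diagonal) posterior precision is the summed log-likelihood curvature plus the prior precision P_0.
print(la.posterior_precision.shape)   # (n_params,)
print(la.prior_precision_diag.shape)  # (n_params,)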

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        dict_key_x,
        dict_key_y,
        backend,
        backend_kwargs,
        asdl_fisher_kwargs,
    )
    if not hasattr(self, "H"):
        self._init_H()
        # posterior mean/mode
        self.mean: float | torch.Tensor = self.prior_mean

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Compute the scatter, a term of the log marginal likelihood that corresponds to L2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute the log determinant of the prior precision, \(\log \det P_0\).

Returns:

  • log_det ( Tensor ) –

log_det_posterior_precision #

log_det_posterior_precision: Tensor

Compute the log determinant of the posterior precision, \(\log \det P\), which depends on the structure used by the subclass for the Hessian approximation.

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Compute or return the posterior precision \(P\).

Returns:

  • posterior_prec ( Tensor ) –

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar (bool, default: False ) –

    whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when subset_of_weights='all'.
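
For example, the override flag supports sequential accumulation (a hedged sketch; loader_task1 and loader_task2 are hypothetical DataLoaders and la is a constructed subclass instance as in the sketch above):

# Online / continual setting: accumulate curvature over two data streams.
la.fit(loader_task1)                  # (re)initializes H, loss, and n_data
la.fit(loader_task2, override=False)  # adds the new batch Hessians to the existing ones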

Source code in laplace/baselaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar : bool, default=False
        whether to show a progress bar; updated at every batch-Hessian computation.
        Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`.
    """
    if override:
        self._init_H()
        self.loss: float | torch.Tensor = 0
        self.n_data: int = 0

    self.model.eval()

    self.mean: torch.Tensor = parameters_to_vector(self.params)
    if not self.enable_backprop:
        self.mean = self.mean.detach()

    data: (
        tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
    ) = next(iter(train_loader))

    with torch.no_grad():
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            if "backpack" in self._backend_cls.__name__.lower() or (
                isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
            ):
                raise ValueError(
                    "Currently DiagEF is not supported under CurvlinopsEF backend "
                    + "for custom models with non-tensor inputs "
                    + "(https://github.com/pytorch/functorch/issues/159). Consider "
                    + "using AsdlEF backend instead. The same limitation applies "
                    + "to all BackPACK backend"
                )

            out = self.model(data)
        else:
            X = data[0]
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                out = self.model(X.to(self._device))
    self.n_outputs = out.shape[-1]
    setattr(self.model, "output_size", self.n_outputs)

    N = len(train_loader.dataset)

    pbar = tqdm.tqdm(train_loader, disable=not progress_bar)
    pbar.set_description("[Computing Hessian]")

    for data in pbar:
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            X, y = data, data[self.dict_key_y].to(self._device)
        else:
            X, y = data
            X, y = X.to(self._device), y.to(self._device)

        if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
            raise ValueError(
                f"The model's output has {out.ndim} dims but "
                f"the target has {y.ndim} dims."
            )

        self.model.zero_grad()
        loss_batch, H_batch = self._curv_closure(X, y, N=N)
        self.loss += loss_batch
        self.H += H_batch

    self.n_data += N

square_norm #

square_norm(value) -> Tensor

Compute the square norm under the posterior precision, with value - self.mean as \(\Delta\):

\[ \Delta^\top P \Delta \]

Returns:

  • square_form
Source code in laplace/baselaplace.py
def square_norm(self, value) -> torch.Tensor:
    """Compute the square norm under post. Precision with `value-self.mean` as 𝛥:

    $$
        \\Delta^\\top P \\Delta
    $$

    Returns
    -------
    square_form
    """
    raise NotImplementedError

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
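
A common use is post-hoc tuning of the prior precision by gradient-based maximization of this quantity (a sketch, assuming la is an already fitted ParametricLaplace instance; the optimizer settings are illustrative):

import torch

log_prior_prec = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.Adam([log_prior_prec], lr=1e-1)
for _ in range(100):
    optimizer.zero_grad()
    # log_marginal_likelihood = log_likelihood - 0.5 * (log_det_ratio + scatter),
    # differentiable w.r.t. the prior precision passed in here.
    neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
    neg_marglik.backward()
    optimizer.step()
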
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive: either the linearized GLM predictive or the neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation covers only a subset of the parameters (i.e., gradients are disabled for the rest), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
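
For example (a sketch, assuming la is a Laplace approximation fitted on a classification task and x is a batch of inputs):

# Probit-approximated GLM predictive over classes, shape (batch_size, n_classes).
probs = la(x, pred_type="glm", link_approx="probit")

# With a regression likelihood, the same call returns a (mean, variance) tuple instead:
# f_mu, f_var = la(x, pred_type="glm")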

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)
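
A hedged Thompson-sampling sketch (x_pool is a hypothetical batch of candidate inputs and la a fitted regression Laplace approximation):

# Draw one function sample per candidate and act greedily on the sampled values.
samples = la.predictive_samples(x_pool, pred_type="glm", n_samples=1)  # (1, batch, 1)
best_candidate = samples[0].squeeze(-1).argmax()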

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

functional_variance #

functional_variance(Js: Tensor) -> Tensor

Compute the functional variance for the 'glm' predictive: f_var[i] = Js[i] @ P.inv() @ Js[i].T, which is an output x output predictive covariance matrix. Mathematically, for a single Jacobian \(\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}\), the output covariance matrix is \( \mathcal{J} P^{-1} \mathcal{J}^T \).

Parameters:

  • Js (Tensor) –

    Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns:

  • f_var ( Tensor ) –

    output covariance (batch, outputs, outputs)
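
For intuition, a sketch of what a diagonal subclass computes here (not the library's exact implementation): with a diagonal posterior precision vector p, \(P^{-1} = \textrm{diag}(1/p)\), and the contraction below applies it between the Jacobians.

# Js: (batch, outputs, parameters); p: (parameters,) diagonal posterior precision.
f_var = torch.einsum("ncp,p,nkp->nck", Js, 1.0 / p, Js)  # (batch, outputs, outputs)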

Source code in laplace/baselaplace.py
def functional_variance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional variance for the `'glm'` predictive:
    `f_var[i] = Js[i] @ P.inv() @ Js[i].T`, which is a output x output
    predictive covariance matrix.
    Mathematically, we have for a single Jacobian
    \\(\\mathcal{J} = \\nabla_\\theta f(x;\\theta)\\vert_{\\theta_{MAP}}\\)
    the output covariance matrix
    \\( \\mathcal{J} P^{-1} \\mathcal{J}^T \\).

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch, outputs, parameters)`

    Returns
    -------
    f_var : torch.Tensor
        output covariance `(batch, outputs, outputs)`
    """
    raise NotImplementedError

functional_covariance #

functional_covariance(Js: Tensor) -> Tensor

Compute functional covariance for the 'glm' predictive: f_cov = Js @ P.inv() @ Js.T, which is a (batch*outputs) x (batch*outputs) predictive covariance matrix.

This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization.

Parameters:

  • Js (Tensor) –

    Jacobians of model output wrt parameters (batch*outputs, parameters)

Returns:

  • f_cov ( Tensor ) –

    output covariance (batch*outputs, batch*outputs)

Source code in laplace/baselaplace.py
def functional_covariance(self, Js: torch.Tensor) -> torch.Tensor:
    """Compute functional covariance for the `'glm'` predictive:
    `f_cov = Js @ P.inv() @ Js.T`, which is a batch*output x batch*output
    predictive covariance matrix.

    This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
    Useful for joint predictions, such as in batched Bayesian optimization.

    Parameters
    ----------
    Js : torch.Tensor
        Jacobians of model output wrt parameters
        `(batch*outputs, parameters)`

    Returns
    -------
    f_cov : torch.Tensor
        output covariance `(batch*outputs, batch*outputs)`
    """
    raise NotImplementedError

sample #

sample(n_samples: int = 100, generator: Generator | None = None) -> Tensor

Sample from the Laplace posterior approximation, i.e., \( \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})\).

Parameters:

  • n_samples (int, default: 100 ) –

    number of samples

  • generator (Generator, default: None ) –

    random number generator to control the samples

Returns:

  • samples ( Tensor ) –
Source code in laplace/baselaplace.py
def sample(
    self, n_samples: int = 100, generator: torch.Generator | None = None
) -> torch.Tensor:
    """Sample from the Laplace posterior approximation, i.e.,
    \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).

    Parameters
    ----------
    n_samples : int, default=100
        number of samples

    generator : torch.Generator, optional
        random number generator to control the samples

    Returns
    -------
    samples: torch.Tensor
    """
    raise NotImplementedError

DiagLaplace #

DiagLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: ParametricLaplace

Laplace approximation with a diagonal log likelihood Hessian approximation and hence a diagonal posterior precision. Mathematically, we have \(P \approx \textrm{diag}(P)\). See BaseLaplace for the full interface.
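
In practice the posterior then factorizes per parameter; a hedged sketch of the resulting quantities (la is a fitted DiagLaplace instance as in the sketch further above):

# Elementwise relations for the diagonal approximation:
#   la.posterior_variance == 1 / la.posterior_precision
#   la.posterior_scale    == la.posterior_variance.sqrt()
theta_samples = la.sample(n_samples=10)  # (10, n_params), theta ~ N(theta_MAP, diag(1/p))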

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        dict_key_x,
        dict_key_y,
        backend,
        backend_kwargs,
        asdl_fisher_kwargs,
    )
    if not hasattr(self, "H"):
        self._init_H()
        # posterior mean/mode
        self.mean: float | torch.Tensor = self.prior_mean

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Compute the scatter, a term of the log marginal likelihood that corresponds to L2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute the log determinant of the prior precision, \(\log \det P_0\).

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: Tensor

Diagonal posterior precision \(p\).

Returns:

  • precision ( tensor ) –

    (parameters)

posterior_scale #

posterior_scale: Tensor

Diagonal posterior scale \(\sqrt{p^{-1}}\).

Returns:

  • scale ( Tensor ) –

    (parameters)

posterior_variance #

posterior_variance: Tensor

Diagonal posterior variance \(p^{-1}\).

Returns:

  • variance ( Tensor ) –

    (parameters)

fit #

fit(train_loader: DataLoader, override: bool = True, progress_bar: bool = False) -> None

Fit the local Laplace approximation at the parameters of the model.

Parameters:

  • train_loader (DataLoader) –

    each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access \(N\), size of the data set.

  • override (bool, default: True ) –

    whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.

  • progress_bar (bool, default: False ) –

    whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when subset_of_weights='all'.

Source code in laplace/baselaplace.py
def fit(
    self,
    train_loader: DataLoader,
    override: bool = True,
    progress_bar: bool = False,
) -> None:
    """Fit the local Laplace approximation at the parameters of the model.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch, either `(X, y)` tensors or a dict-like
        object containing keys as expressed by `self.dict_key_x` and
        `self.dict_key_y`. `train_loader.dataset` needs to be set to access
        \\(N\\), size of the data set.
    override : bool, default=True
        whether to initialize H, loss, and n_data again; setting to False is useful for
        online learning settings to accumulate a sequential posterior approximation.
    progress_bar : bool, default=False
        whether to show a progress bar; updated at every batch-Hessian computation.
        Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`.
    """
    if override:
        self._init_H()
        self.loss: float | torch.Tensor = 0
        self.n_data: int = 0

    self.model.eval()

    self.mean: torch.Tensor = parameters_to_vector(self.params)
    if not self.enable_backprop:
        self.mean = self.mean.detach()

    data: (
        tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
    ) = next(iter(train_loader))

    with torch.no_grad():
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            if "backpack" in self._backend_cls.__name__.lower() or (
                isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
            ):
                raise ValueError(
                    "Currently DiagEF is not supported under CurvlinopsEF backend "
                    + "for custom models with non-tensor inputs "
                    + "(https://github.com/pytorch/functorch/issues/159). Consider "
                    + "using AsdlEF backend instead. The same limitation applies "
                    + "to all BackPACK backend"
                )

            out = self.model(data)
        else:
            X = data[0]
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                out = self.model(X.to(self._device))
    self.n_outputs = out.shape[-1]
    setattr(self.model, "output_size", self.n_outputs)

    N = len(train_loader.dataset)

    pbar = tqdm.tqdm(train_loader, disable=not progress_bar)
    pbar.set_description("[Computing Hessian]")

    for data in pbar:
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            X, y = data, data[self.dict_key_y].to(self._device)
        else:
            X, y = data
            X, y = X.to(self._device), y.to(self._device)

        if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
            raise ValueError(
                f"The model's output has {out.ndim} dims but "
                f"the target has {y.ndim} dims."
            )

        self.model.zero_grad()
        loss_batch, H_batch = self._curv_closure(X, y, N=N)
        self.loss += loss_batch
        self.H += H_batch

    self.n_data += N

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive: either the linearized GLM predictive or the neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation covers only a subset of the parameters (i.e., gradients are disabled for the rest), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
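
For intuition only, a minimal sketch of the full-covariance case using torch.distributions instead of the internal normal_samples helper (it assumes f_var is a valid positive-definite covariance per batch element):

import torch

def glm_samples_sketch(f_mu, f_var, n_samples):
    # f_mu: (batch, K), f_var: (batch, K, K) -> samples: (n_samples, batch, K)
    dist = torch.distributions.MultivariateNormal(loc=f_mu, covariance_matrix=f_var)
    f_samples = dist.sample((n_samples,))
    return torch.softmax(f_samples, dim=-1)  # drop the softmax for regression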

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob
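
As a sanity-check sketch (assuming `la` is a fitted Laplace object exposing a dense posterior_covariance, e.g. a full approximation), the normalized branch matches the standard multivariate Gaussian log-density:

import torch

theta = torch.randn_like(la.mean)            # an arbitrary parameter vector
ref = torch.distributions.MultivariateNormal(
    la.mean, covariance_matrix=la.posterior_covariance
)
assert torch.allclose(la.log_prob(theta), ref.log_prob(theta), atol=1e-4)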

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)
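
A minimal Thompson-sampling-style sketch (names such as x_candidates are illustrative; it assumes `la` is a fitted regression Laplace over a one-dimensional reward output):

# draw a single function sample over the candidate inputs and act greedily on it
samples = la.predictive_samples(x_candidates, pred_type="glm", n_samples=1)
best_idx = samples[0].squeeze(-1).argmax()   # candidate to query next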

KronLaplace #

KronLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Bases: ParametricLaplace

Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., torch.nn.Module, that \(P \approx Q \otimes H\). See BaseLaplace for the full interface and see laplace.utils.matrix.Kron and laplace.utils.matrix.KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, apply a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by setting damping=True.

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    damping: bool = False,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
):
    self.damping: bool = damping
    self.H_facs: Kron | None = None
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        dict_key_x,
        dict_key_y,
        backend,
        backend_kwargs,
        asdl_fisher_kwargs,
    )
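
A brief construction sketch (assuming the usual `model`/`train_loader` setup; the same object can also be obtained through the Laplace factory with hessian_structure='kron'):

from laplace import KronLaplace

la = KronLaplace(model, likelihood="classification", damping=False)
la.fit(train_loader)                              # accumulates Kronecker factors per parameter group
la.optimize_prior_precision(method="marglik")     # optional: post-hoc prior tuning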

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: KronDecomposed

Kronecker factored Posterior precision \(P\).

Returns:

  • precision ( `laplace.utils.matrix.KronDecomposed` ) –

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
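
Because the returned tensor is differentiable in the hyperparameters, it can drive a simple post-hoc tuning loop; a sketch (assuming `la` has already been fitted; the optimizer settings are illustrative):

import torch

log_prior_prec = torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior_prec], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
    neg_marglik.backward()
    hyper_optimizer.step()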

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation is applied only to a subset of parameters (i.e. some gradients are disabled), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

LowRankLaplace #

LowRankLaplace(model: Module, likelihood: Likelihood | str, backend: type[CurvatureInterface] = AsdfghjklHessian if find_spec('asdfghjkl') is not None else CurvatureInterface, sigma_noise: float | Tensor = 1, prior_precision: float | Tensor = 1, prior_mean: float | Tensor = 0, temperature: float = 1, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend_kwargs: dict[str, Any] | None = None)

Bases: ParametricLaplace

Laplace approximation with low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Based on the chosen backend, either a true Hessian or, for example, a GGN approximation could be used. The posterior precision is computed as \( P = V \textrm{diag}(l) V^T + P_0.\) To sample, compute the functional variance, and compute the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a \(K \times K\) matrix for a rank-\(K\) approximation.

Note that only the AsdfghjklHessian backend is supported. Install it via: pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl

See BaseLaplace for the full interface.

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    backend: type[CurvatureInterface] = AsdfghjklHessian
    if find_spec("asdfghjkl") is not None
    else CurvatureInterface,
    sigma_noise: float | torch.Tensor = 1,
    prior_precision: float | torch.Tensor = 1,
    prior_mean: float | torch.Tensor = 0,
    temperature: float = 1,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend_kwargs: dict[str, Any] | None = None,
):
    if find_spec("asdfghjkl") is None:
        raise ImportError(
            """To use LowRankLaplace, please install the old asdfghjkl dependency: """
            """pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl"""
        )

    super().__init__(
        model,
        likelihood,
        sigma_noise=sigma_noise,
        prior_precision=prior_precision,
        prior_mean=prior_mean,
        temperature=temperature,
        enable_backprop=enable_backprop,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
        backend=backend,
        backend_kwargs=backend_kwargs,
    )
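
A brief construction sketch (assuming the optional asdfghjkl dependency is installed and `model`, `train_loader` are defined as usual):

from laplace import LowRankLaplace

la = LowRankLaplace(model, likelihood="regression")
la.fit(train_loader)   # stores the rank-K eigendecomposition of the Hessian approximation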

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_precision #

posterior_precision: tuple[tuple[Tensor, Tensor], Tensor]

Return correctly scaled posterior precision that would be constructed as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.

Returns:

  • H ( tuple(eigenvectors, eigenvalues) ) –

    scaled self.H with temperature and loss factors.

  • prior_precision_diag ( Tensor ) –

    diagonal prior precision, shape (parameters,), to be added to H.
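
The property returns the factors rather than a dense matrix; a purely illustrative sketch of how the full precision could be assembled for a small model (infeasible for large parameter counts):

import torch

(eigvecs, eigvals), prior_prec_diag = la.posterior_precision
P = eigvecs @ torch.diag(eigvals) @ eigvecs.T + torch.diag(prior_prec_diag)   # (n_params, n_params)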

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation is applied only to a subset of parameters (i.e. some gradients are disabled), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)

square_norm #

square_norm(value) -> Tensor

Compute the square norm under the posterior precision, with value - self.mean as \(\Delta\):

\[ \Delta^\top P \Delta \]

Returns:

  • square_form
Source code in laplace/baselaplace.py
def square_norm(self, value) -> torch.Tensor:
    """Compute the square norm under post. Precision with `value-self.mean` as 𝛥:

    $$
        \\Delta^\top P \\Delta
    $$

    Returns
    -------
    square_form
    """
    raise NotImplementedError
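
The base class leaves this abstract; a sketch of what a dense-precision subclass computes (assuming posterior_precision is a full (n_params, n_params) tensor, as in the full approximation):

import torch

def square_norm_full_sketch(la, value):
    # Delta^T P Delta with Delta = value - mean, batched over the rows of `value`
    delta = value - la.mean
    return torch.einsum("np,pq,nq->n", delta, la.posterior_precision, delta)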

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)

FullLaplace #

FullLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None)

Bases: ParametricLaplace

Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have \(P \in \mathbb{R}^{P \times P}\). See BaseLaplace for the full interface.

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    backend_kwargs: dict[str, Any] | None = None,
):
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        dict_key_x,
        dict_key_y,
        backend,
        backend_kwargs,
    )
    self._posterior_scale: torch.Tensor | None = None
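
A brief construction sketch (assuming a small model, since the dense precision needs O(P^2) memory; the Laplace factory with hessian_structure='full' is an equivalent entry point):

from laplace import FullLaplace

la = FullLaplace(model, likelihood="regression", sigma_noise=0.3)
la.fit(train_loader)   # builds the dense (n_params, n_params) posterior precision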

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

scatter #

scatter: Tensor

Computes the scatter, a term of the log marginal likelihood that corresponds to L-2 regularization: scatter = \((\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) \).

Returns:

  • scatter ( Tensor ) –

log_det_prior_precision #

log_det_prior_precision: Tensor

Compute log determinant of the prior precision \(\log \det P_0\)

Returns:

  • log_det ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Compute the log determinant ratio, a part of the log marginal likelihood.

\[ \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 \]

Returns:

  • log_det_ratio ( Tensor ) –

posterior_scale #

posterior_scale: Tensor

Posterior scale (square root of the covariance), i.e., \(P^{-\frac{1}{2}}\).

Returns:

  • scale ( Tensor ) –

    (parameters, parameters)

posterior_covariance #

posterior_covariance: Tensor

Posterior covariance, i.e., \(P^{-1}\).

Returns:

  • covariance ( Tensor ) –

    (parameters, parameters)

posterior_precision #

posterior_precision: Tensor

Posterior precision \(P\).

Returns:

  • precision ( Tensor ) –

    (parameters, parameters)
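
These three quantities are mutually consistent: the scale is a matrix square root of the covariance, which is itself the inverse of the precision. A small sanity-check sketch, assuming a fitted FullLaplace object la (tolerances are illustrative):

import torch

scale = la.posterior_scale        # P^{-1/2}, shape (n_params, n_params)
cov = la.posterior_covariance     # P^{-1}
prec = la.posterior_precision     # P

# Up to numerical error, the scale factorizes the covariance and the covariance
# inverts the precision.
assert torch.allclose(scale @ scale.T, cov, atol=1e-5)
assert torch.allclose(cov @ prec, torch.eye(la.n_params), atol=1e-4)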

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision, if it should be changed from the current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation, if it should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood subject
    to specific Hessian approximations that subclasses implement.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")

        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
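
Because the returned tensor is differentiable in prior_precision and sigma_noise, both hyperparameters can be tuned by gradient ascent on the evidence after fitting. A minimal post-hoc tuning sketch for a regression likelihood (the fitted object la, the learning rate, and the number of steps are illustrative assumptions):

import torch

log_prior = torch.zeros(1, requires_grad=True)   # log of a scalar prior precision
log_sigma = torch.zeros(1, requires_grad=True)   # log of the observation noise std. dev.
optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    optimizer.zero_grad()
    # Maximizing the evidence = minimizing the negative log marginal likelihood.
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    optimizer.step()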

__call__ #

__call__(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if a tensor. If a MutableMapping, it must contain that tensor.

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive: linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here. When the Laplace approximation covers only a subset of parameters (i.e., when some gradients are disabled), only the nn predictive is supported.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]). If False, only the marginal predictive distribution is returned. Only available for regression and the GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' when joint=False in regression. In the case of last-layer Laplace with a diagonal or Kron Hessian, setting this to True makes computation much(!) faster for a large number of outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here. When Laplace is done only
        on subset of parameters (i.e. some grad are disabled),
        only `nn` predictive is supported.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` when `joint=False` in regression.
        In the case of last-layer Laplace with a diagonal or Kron Hessian,
        setting this to `True` makes computation much(!) faster for large
        number of outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if pred_type not in [pred for pred in PredType]:
        raise ValueError("Only glm and nn supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if pred_type == PredType.NN and link_approx != LinkApprox.MC:
        raise ValueError(
            "Only mc link approximation is supported for nn prediction type."
        )

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != self._device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    if pred_type == PredType.GLM:
        return self._glm_forward_call(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )
    else:
        if likelihood == Likelihood.REGRESSION:
            samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
            return samples.mean(dim=0), samples.var(dim=0)
        else:  # classification; the average is computed online
            return self._nn_predictive_classification(x, n_samples, **model_kwargs)
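
A minimal prediction sketch for a classification likelihood, assuming a fitted Laplace object la and a test batch x_test (both placeholders):

# GLM predictive with the probit link approximation: class probabilities per input.
probs = la(x_test, pred_type="glm", link_approx="probit")                 # (batch_size, n_classes)

# Monte-Carlo link approximation instead, averaging 1000 predictive samples.
probs_mc = la(x_test, pred_type="glm", link_approx="mc", n_samples=1000)

# With likelihood='regression' the same call returns a (mean, variance) tuple:
# f_mu, f_var = la(x_test, pred_type="glm")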

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if a tensor. If a MutableMapping, it must contain that tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]). If False, only the marginal predictive distribution is returned. Only available for regression and the GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
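
The probit branch above is the classic probit approximation to the softmax-Gaussian integral: with predictive mean \(\mu_k\) and marginal variance \(v_k\) for output \(k\), it computes

\[ p_k \approx \mathrm{softmax}\left(\frac{\mu}{\sqrt{1 + \pi v / 8}}\right)_k, \]

which is exactly the kappa * f_mu rescaling in the code, applied element-wise over the outputs before the softmax.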

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
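
The sampling itself is delegated to the normal_samples helper. Conceptually, for a full (non-diagonal) f_var it draws correlated Gaussian outputs per input; a standalone illustration of that idea under the shapes stated above (a sketch, not the library's normal_samples implementation):

import torch

def sample_outputs(f_mu, f_var, n_samples, generator=None):
    """Draw from N(f_mu[b], f_var[b]) for every input b via a Cholesky factor."""
    scale = torch.linalg.cholesky(f_var)                            # (batch, out, out); needs PD f_var
    eps = torch.randn(n_samples, *f_mu.shape, generator=generator)  # (n_samples, batch, out)
    return f_mu.unsqueeze(0) + torch.einsum("boi,nbi->nbo", scale, eps)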

log_prob #

log_prob(value: Tensor, normalized: bool = True) -> Tensor

Compute the log probability under the (current) Laplace approximation.

Parameters:

  • value (Tensor) –
  • normalized (bool, default: True ) –

    whether to return the log of a properly normalized Gaussian or just the terms that depend on value.

Returns:

  • log_prob ( Tensor ) –
Source code in laplace/baselaplace.py
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Compute the log probability under the (current) Laplace approximation.

    Parameters
    ----------
    value: torch.Tensor
    normalized : bool, default=True
        whether to return log of a properly normalized Gaussian or just the
        terms that depend on `value`.

    Returns
    -------
    log_prob : torch.Tensor
    """
    if not normalized:
        return -self.square_norm(value) / 2
    log_prob = (
        -self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
    )
    log_prob -= self.square_norm(value) / 2
    return log_prob
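
In the normalized case this evaluates the log density of the Gaussian posterior approximation at the queried parameter vector \(\theta\), with \(D\) the number of parameters and the quadratic term supplied by square_norm:

\[ \log q(\theta) = -\tfrac{D}{2}\log(2\pi) + \tfrac{1}{2}\log\det P - \tfrac{1}{2}(\theta - \theta_{MAP})^{T} P (\theta - \theta_{MAP}). \]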

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type (('glm', 'nn'), default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm', 'nn'}, default='glm'
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with
        the curvature approximations used here.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm and nn supported as prediction types.")

    if pred_type == PredType.GLM:
        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_predictive_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    else:  # 'nn'
        return self._nn_predictive_samples(x, n_samples, generator)