
laplace.baselaplace #

FunctionalLaplace #

FunctionalLaplace(model: Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x='input_ids', dict_key_y='labels', backend: type[CurvatureInterface] | None = BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)

Bases: BaseLaplace

Applying the GGN (Generalized Gauss-Newton) approximation for the Hessian in the Laplace approximation of the posterior turns the underlying probabilistic model from a BNN into a GLM (generalized linear model). This GLM (in the weight space) is equivalent to a GP (in the function space), see Approximate Inference Turns Deep Networks into Gaussian Processes (Khan et al., 2019)

This class implements the (approximate) GP inference through which we obtain the desired quantities (posterior predictive, marginal log-likelihood). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021) for more details.

Note that for likelihood='classification', we approximate \( L_{NN} \) with a diagonal matrix ( \( L_{NN} \) is a block-diagonal matrix whose blocks are the Hessians of the per-data-point log-likelihood w.r.t. the neural network output \( f \); see Appendix A.2.1 for the exact definition). We resort to this approximation because of possible issues with the Laplace approximation for multiclass GP classification, discussed in Chapter 3.5 of the R&W 2006 GP book. Alternatively, one could resort to one-vs-one or one-vs-rest implementations for multiclass classification; however, that is not (yet) supported here.
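
A minimal usage sketch (illustration only; the model, data, and import path below are assumptions, not part of the documented API):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.baselaplace import FunctionalLaplace  # assumed import path for this module

# toy classification data and model (placeholders)
X = torch.randn(256, 10)
y = torch.randint(0, 3, (256,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=32)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

# approximate GP inference on a subset of 64 training points
la = FunctionalLaplace(model, "classification", n_subset=64)
la.fit(train_loader)

x_test = torch.randn(5, 10)
probs = la(x_test, pred_type="gp", link_approx="probit")  # (5, 3) class probabilities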

Parameters:

  • n_subset (int) –

    number of data points used for Subset-of-Data (SoD) approximate GP inference.

  • independent_outputs (bool) –

    The GP kernel here is a product of Jacobians, which results in a \( C \times C\) matrix, where \(C\) is the output dimension. If independent_outputs=True, only the diagonal of the GP kernel is used. This is (somewhat) equivalent to assuming independent GPs across output channels.

  • See BaseLaplace for the full interface and the remaining parameters.
Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    n_subset: int,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x="input_ids",
    dict_key_y="labels",
    backend: type[CurvatureInterface] | None = BackPackGGN,
    backend_kwargs: dict[str, Any] | None = None,
    independent_outputs: bool = False,
    seed: int = 0,
):
    assert backend in [BackPackGGN, AsdlGGN, CurvlinopsGGN]
    self._check_prior_precision(prior_precision)
    super().__init__(
        model,
        likelihood,
        sigma_noise,
        prior_precision,
        prior_mean,
        temperature,
        enable_backprop,
        dict_key_x,
        dict_key_y,
        backend,
        backend_kwargs,
    )
    self.enable_backprop = enable_backprop

    self.n_subset = n_subset
    self.independent_outputs = independent_outputs
    self.seed = seed

    self.K_MM = None
    self.Sigma_inv = None  # (K_{MM} + L_MM_inv)^{-1}
    self.train_loader = (
        None  # needed in functional variance and marginal log likelihood
    )
    self.batch_size = None
    self._prior_factor_sod = None
    self.mu = None  # mean in the scatter term of the log marginal likelihood
    self.L = None

    # Posterior mean (used in regression marginal likelihood)
    self.mean = parameters_to_vector(self.model.parameters()).detach()

    self._fitted = False
    self._recompute_Sigma = True

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –

log_det_ratio #

log_det_ratio: Tensor

Computes the log determinant term in the GP marginal likelihood.

For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book (note that we always use the diagonal approximation \(D\) of the Hessian of the log likelihood w.r.t. \(f\)):

log determinant term := \( \log | I + D^{1/2}K D^{1/2} | \)

For regression, we use "standard" GP marginal likelihood:

log determinant term := \( \log | K + \sigma_2 I | \)
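
For intuition, a toy computation of both log-determinant terms with randomly generated stand-ins for \(K\), \(D\), and \(\sigma\) (not the library's internal code path):

import torch

M = 8                                   # toy number of subset points
A = torch.randn(M, M)
K = A @ A.T + 1e-3 * torch.eye(M)       # toy positive-definite kernel matrix
D = torch.rand(M)                       # toy diagonal Hessian of the log likelihood w.r.t. f
sigma = 0.7                             # toy observation noise std

# classification: log |I + D^{1/2} K D^{1/2}|
D_sqrt = D.sqrt()
log_det_cls = torch.logdet(torch.eye(M) + D_sqrt[:, None] * K * D_sqrt[None, :])

# regression: log |K + sigma^2 I|
log_det_reg = torch.logdet(K + sigma**2 * torch.eye(M))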

scatter #

scatter: Tensor

Compute the scatter term in the GP log marginal likelihood.

For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book with \(\hat{f} = f \):

scatter term := \( f K^{-1} f^{T} \)

For regression, we use "standard" GP marginal likelihood:

scatter term := \( (y - m)K^{-1}(y -m )^T \), where \( m \) is the mean of the GP prior, which in our case corresponds to \( m := f + J (\theta - \theta_{MAP}) \)
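
A toy computation of both scatter terms, again with random stand-ins (illustration only):

import torch

M = 8
A = torch.randn(M, M)
K = A @ A.T + 1e-3 * torch.eye(M)       # toy kernel matrix
f = torch.randn(M)                      # toy network outputs (classification: \hat{f} = f)
y = torch.randn(M)                      # toy regression targets
m = torch.randn(M)                      # toy GP prior mean, m = f + J (theta - theta_MAP)

scatter_cls = f @ torch.linalg.solve(K, f)              # f K^{-1} f^T
scatter_reg = (y - m) @ torch.linalg.solve(K, y - m)    # (y - m) K^{-1} (y - m)^T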

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
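
As a standalone illustration of the probit branch above, the probit approximation to the softmax link can be reproduced on made-up predictive moments (illustration only, not the library's code path):

import torch

# made-up GLM predictive moments for a batch of 4 points and 3 classes
f_mu = torch.randn(4, 3)                        # (batch, C)
f_var = 0.5 * torch.eye(3).expand(4, 3, 3)      # (batch, C, C)

# probit approximation: scale the logits by kappa and push them through the softmax
kappa = 1 / torch.sqrt(1.0 + torch.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
probs = torch.softmax(kappa * f_mu, dim=-1)     # rows sum to 1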

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type.

Parameters:

  • f_mu (Tensor or MutableMapping) –

    glm predictive mean (batch_size, output_shape)

  • f_var (Tensor or MutableMapping) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples (int) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor or MutableMapping
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor or MutableMapping
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    f_samples = normal_samples(f_mu, f_var, n_samples, generator)

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
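
A sketch of what the MC path amounts to for a diagonal predictive; this re-implements the sampling step with plain torch rather than the library's normal_samples helper (toy inputs only):

import torch

f_mu = torch.randn(4, 3)                    # toy predictive mean (batch, C)
f_var_diag = torch.rand(4, 3) + 0.1         # toy diagonal predictive variances (batch, C)
n_samples = 100

# draw Gaussian logit samples and average the softmax over them
eps = torch.randn(n_samples, *f_mu.shape)               # (n_samples, batch, C)
f_samples = f_mu + f_var_diag.sqrt() * eps
probs = torch.softmax(f_samples, dim=-1).mean(dim=0)    # MC estimate of class probabilities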

_check_prior_precision #

_check_prior_precision(prior_precision: float | Tensor)

Checks if the given prior precision is suitable for the GP interpretation of LLA. As such, only single value priors, i.e., isotropic priors are suitable.

Source code in laplace/baselaplace.py
@staticmethod
def _check_prior_precision(prior_precision: float | torch.Tensor):
    """Checks if the given prior precision is suitable for the GP interpretation of LLA.
    As such, only single value priors, i.e., isotropic priors are suitable.
    """
    if torch.is_tensor(prior_precision):
        if not (
            prior_precision.ndim == 0
            or (prior_precision.ndim == 1 and len(prior_precision) == 1)
        ):
            raise ValueError("Only isotropic priors supported in FunctionalLaplace")

_init_K_MM #

_init_K_MM()

Allocates memory for the kernel matrix evaluated at the subset of the training data points. If the subset is of size \(M\) and the problem has \(C\) outputs, this is a list of \(C\) tensors of shape \((M, M)\) for the diagonal kernel and a single \((M \times C, M \times C)\) tensor otherwise.

Source code in laplace/baselaplace.py
def _init_K_MM(self):
    """Allocates memory for the kernel matrix evaluated at the subset of the training
    data points. If the subset is of size \\(M\\) and the problem has \\(C\\) outputs,
    this is a list of C \\((M,M\\)) tensors for diagonal kernel and \\((M x C, M x C)\\)
    otherwise.
    """
    if self.independent_outputs:
        self.K_MM = [
            torch.empty(size=(self.n_subset, self.n_subset), device=self._device)
            for _ in range(self.n_outputs)
        ]
    else:
        self.K_MM = torch.empty(
            size=(self.n_subset * self.n_outputs, self.n_subset * self.n_outputs),
            device=self._device,
        )

_init_Sigma_inv #

_init_Sigma_inv()

Allocates memory for the Cholesky decomposition of \( K_{MM} + \Lambda_{MM}^{-1} \). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021), Equation 15, for more information.

Source code in laplace/baselaplace.py
def _init_Sigma_inv(self):
    """Allocates memory for the cholesky decomposition of
    \\[
        K_{MM} + \\Lambda_{MM}^{-1}.
    \\]
    See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
    Equation 15 for more information.
    """
    if self.independent_outputs:
        self.Sigma_inv = [
            torch.empty(size=(self.n_subset, self.n_subset), device=self._device)
            for _ in range(self.n_outputs)
        ]
    else:
        self.Sigma_inv = torch.empty(
            size=(self.n_subset * self.n_outputs, self.n_subset * self.n_outputs),
            device=self._device,
        )

_store_K_batch #

_store_K_batch(K_batch: Tensor, i: int, j: int)

Given the kernel matrix between the i-th and the j-th batch, stores it in the corresponding position in self.K_MM.

Source code in laplace/baselaplace.py
def _store_K_batch(self, K_batch: torch.Tensor, i: int, j: int):
    """Given the kernel matrix between the i-th and the j-th batch, stores it in the
    corresponding position in self.K_MM.
    """
    if self.independent_outputs:
        for c in range(self.n_outputs):
            self.K_MM[c][
                i * self.batch_size : min((i + 1) * self.batch_size, self.n_subset),
                j * self.batch_size : min((j + 1) * self.batch_size, self.n_subset),
            ] = K_batch[:, :, c]
            if i != j:
                self.K_MM[c][
                    j * self.batch_size : min(
                        (j + 1) * self.batch_size, self.n_subset
                    ),
                    i * self.batch_size : min(
                        (i + 1) * self.batch_size, self.n_subset
                    ),
                ] = torch.transpose(K_batch[:, :, c], 0, 1)
    else:
        bC = self.batch_size * self.n_outputs
        MC = self.n_subset * self.n_outputs
        self.K_MM[
            i * bC : min((i + 1) * bC, MC), j * bC : min((j + 1) * bC, MC)
        ] = K_batch
        if i != j:
            self.K_MM[
                j * bC : min((j + 1) * bC, MC), i * bC : min((i + 1) * bC, MC)
            ] = torch.transpose(K_batch, 0, 1)

_build_L #

_build_L(lambdas: list[Tensor])

Given a list of the Hessians of the per-batch log-likelihood w.r.t. the neural network output \( f \), returns the concatenation of these Hessians in a format suitable for the kernel used (diagonal or not).

In this function the diagonal approximation is performed. Please refer to the introduction of the class for more details.

Parameters:

  • lambdas (list of torch.Tensor of shape (b, C, C)) –
      Contains the Hessians of the per-batch log-likelihood w.r.t. the neural network output \( f \).
    

Returns:

  • L ( list with length C of tensors with shape M or tensor (MxC) ) –

    Contains the given Hessians in a suitable format.

Source code in laplace/baselaplace.py
def _build_L(self, lambdas: list[torch.Tensor]):
    """Given a list of the Hessians of per-batch log-likelihood w.r.t. neural network output \\( f \\),
    returns the concatenation of these Hessians in a format suitable for the kernel used
    (diagonal or not).

    In this function the diagonal approximation is performed. Please refer to the introduction of the
    class for more details.

    Parameters
    ----------
    lambdas : list of torch.Tensor of shape (b, C, C)
              Contains the Hessians of the per-batch log-likelihood w.r.t. the neural network output \\( f \\).

    Returns
    -------
    L : list with length C of tensors with shape M or tensor (MxC)
        Contains the given Hessians in a suitable format.
    """
    # Concatenate batch dimension and discard non-diagonal entries.
    L_diag = torch.diagonal(torch.cat(lambdas, dim=0), dim1=-2, dim2=-1).reshape(-1)

    if self.independent_outputs:
        return [L_diag[i :: self.n_outputs] for i in range(self.n_outputs)]
    else:
        return L_diag

_build_Sigma_inv #

_build_Sigma_inv()

Computes the Cholesky decomposition of \( K_{MM} + \Lambda_{MM}^{-1} \). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021), Equation 15, for more information.

As the diagonal approximation is performed with \(\Lambda_{MM}\) (which is stored in self.L), the code is greatly simplified.

Source code in laplace/baselaplace.py
def _build_Sigma_inv(self):
    """Computes the cholesky decomposition of
            \\[
                K_{MM} + \\Lambda_{MM}^{-1}.
            \\]
            See See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
            Equation 15 for more information.

    <<<<<<< HEAD
            As the diagonal approximation is performed with \\Lambda_{MM} (which is stored in self.L),
    =======
            As the diagonal approximation is performed with \\(\\Lambda_{MM}\\) (which is stored in self.L),
    >>>>>>> main
            the code is greatly simplified.
    """
    if self.independent_outputs:
        self.Sigma_inv = [
            torch.linalg.cholesky(
                self.gp_kernel_prior_variance * self.K_MM[c]
                + torch.diag(
                    torch.nan_to_num(1.0 / (self._H_factor * lambda_c), posinf=10.0)
                )
            )
            for c, lambda_c in enumerate(self.L)
        ]
    else:
        self.Sigma_inv = torch.linalg.cholesky(
            self.gp_kernel_prior_variance * self.K_MM
            + torch.diag(
                torch.nan_to_num(1 / (self._H_factor * self.L), posinf=10.0)
            )
        )
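
A toy version of the non-diagonal branch, with random stand-ins for the kernel and \(\Lambda_{MM}\) (illustration only; the prior-variance scaling and the _H_factor term are omitted):

import torch

M = 5
A = torch.randn(M, M)
K_MM = A @ A.T + 1e-3 * torch.eye(M)    # toy kernel over the subset
lam = torch.rand(M) + 0.1               # toy diagonal of Lambda_{MM}

# Cholesky factor of K_MM + Lambda_{MM}^{-1}
Sigma_inv_chol = torch.linalg.cholesky(K_MM + torch.diag(1.0 / lam))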

_get_SoD_data_loader #

_get_SoD_data_loader(train_loader: DataLoader) -> DataLoader

Subset-of-Datapoints data loader

Source code in laplace/baselaplace.py
def _get_SoD_data_loader(self, train_loader: DataLoader) -> DataLoader:
    """Subset-of-Datapoints data loader"""
    return DataLoader(
        dataset=train_loader.dataset,
        batch_size=train_loader.batch_size,
        sampler=SoDSampler(
            N=len(train_loader.dataset), M=self.n_subset, seed=self.seed
        ),
        shuffle=False,
    )

fit #

fit(train_loader: DataLoader | MutableMapping, progress_bar: bool = False)

Fit the Laplace approximation of a GP posterior.

Parameters:

  • train_loader (DataLoader) –

    train_loader.dataset needs to be set to access \(N\), the size of the data set, and train_loader.batch_size needs to be set to access the batch size \(b\).

  • progress_bar (bool, default: False ) –

    whether to show a progress bar during the fitting process.

Source code in laplace/baselaplace.py
def fit(
    self, train_loader: DataLoader | MutableMapping, progress_bar: bool = False
):
    """Fit the Laplace approximation of a GP posterior.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader
        `train_loader.dataset` needs to be set to access \\(N\\), the size of the data set;
        `train_loader.batch_size` needs to be set to access the batch size \\(b\\)
    progress_bar : bool
        whether to show a progress bar during the fitting process.
    """
    # Set model to evaluation mode
    self.model.eval()

    data = next(iter(train_loader))
    with torch.no_grad():
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            if "backpack" in self._backend_cls.__name__.lower():
                raise ValueError(
                    "Currently BackPACK backend is not supported "
                    + "for custom models with non-tensor inputs "
                    + "(https://github.com/pytorch/functorch/issues/159). Consider "
                    + "using AsdlGGN backend instead."
                )

            out = self.model(data)
        else:
            X = data[0]
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                out = self.model(X.to(self._device))
    self.n_outputs = out.shape[-1]
    setattr(self.model, "output_size", self.n_outputs)
    self.batch_size = train_loader.batch_size

    if (
        self.likelihood == "regression"
        and self.n_outputs > 1
        and self.independent_outputs
    ):
        warnings.warn(
            "Using FunctionalLaplace with the diagonal approximation of a GP kernel is not recommended "
            "in the case of multivariate regression. Predictive variance will likely be overestimated."
        )

    N = len(train_loader.dataset)
    self.n_data = N

    assert (
        self.n_subset <= N
    ), "`num_data` must be less than or equal to the original number of data points."

    train_loader = self._get_SoD_data_loader(train_loader)
    self.train_loader = train_loader
    self._prior_factor_sod = self.n_subset / self.n_data

    self._init_K_MM()
    self._init_Sigma_inv()

    f, lambdas, mu = [], [], []

    if progress_bar:
        loader = enumerate(tqdm.tqdm(train_loader, desc="Fitting"))
    else:
        loader = enumerate(train_loader)

    for i, data in loader:
        if isinstance(data, MutableMapping):  # To support Huggingface dataset
            X, y = data, data[self.dict_key_y].to(self._device)
        else:
            X, y = data
            X, y = X.to(self._device), y.to(self._device)

        Js_batch, f_batch = self._jacobians(X, enable_backprop=False)

        if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
            raise ValueError(
                f"The model's output has {out.ndim} dims but "
                f"the target has {y.ndim} dims."
            )

        with torch.no_grad():
            loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y)

        if self.likelihood == Likelihood.REGRESSION:
            b, C = f_batch.shape
            lambdas_batch = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1)
        else:
            # second derivative of log lik is diag(p) - pp^T
            ps = torch.softmax(f_batch, dim=-1)
            lambdas_batch = torch.diag_embed(ps) - torch.einsum(
                "mk,mc->mck", ps, ps
            )

        self.loss += loss_batch
        lambdas.append(lambdas_batch)
        f.append(f_batch)
        mu.append(
            self._mean_scatter_term_batch(Js_batch, f_batch, y)
        )  # needed for marginal likelihood
        for j, (X2, _) in enumerate(train_loader):
            if j >= i:
                X2 = X2.to(self._device)
                K_batch = self._kernel_batch(Js_batch, X2)
                self._store_K_batch(K_batch, i, j)

    self.L = self._build_L(lambdas)
    self.mu = torch.cat(mu, dim=0)
    self._build_Sigma_inv()
    self._fitted = True

__call__ #

__call__(x: Tensor | MutableMapping, pred_type: PredType | str = PredType.GP, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x.

Parameters:

  • x (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • pred_type ('gp', default: 'gp' ) –

    type of posterior predictive, linearized GLM predictive (GP). The GP predictive is consistent with the curvature approximations used here.

  • link_approx (('mc', 'probit', 'bridge', 'bridge_norm'), default: 'probit' ) –

    how to approximate the classification link function for the 'glm'.

  • joint (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for link_approx='mc'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used).

  • fitting (bool, default: False ) –

    whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.

Returns:

  • predictive ( Tensor or Tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
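
A hedged usage sketch for the regression case (the model, data, and import path are placeholders; classification usage is shown in the sketch at the top of this page):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.baselaplace import FunctionalLaplace  # assumed import path

X, y = torch.randn(128, 5), torch.randn(128, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=16)
reg_model = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 1))

la_reg = FunctionalLaplace(reg_model, "regression", n_subset=32)
la_reg.fit(loader)

x_test = torch.randn(4, 5)
f_mu, f_var = la_reg(x_test)                # marginal predictive mean and variance
f_mu_j, f_cov = la_reg(x_test, joint=True)  # GP-style joint mean and covariance over the batch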

Source code in laplace/baselaplace.py
def __call__(
    self,
    x: torch.Tensor | MutableMapping,
    pred_type: PredType | str = PredType.GP,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
    fitting: bool = False,
    **model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    pred_type : {'gp'}, default='gp'
        type of posterior predictive, linearized GLM predictive (GP).
        The GP predictive is consistent with
        the curvature approximations used here.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `link_approx='mc'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used).

    fitting : bool, default=False
        whether or not this predictive call is done during fitting. Only useful for
        reward modeling: the likelihood is set to `"regression"` when `False` and
        `"classification"` when `True`.

    Returns
    -------
    predictive: torch.Tensor or Tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    if self._fitted is False:
        raise RuntimeError(
            "Functional Laplace has not been fitted to any "
            + "training dataset. Please call .fit method."
        )

    if self._recompute_Sigma is True:
        warnings.warn(
            "The prior precision has been changed since fit. "
            + "Re-compututing its value..."
        )
        self._build_Sigma_inv()

    if pred_type != PredType.GP:
        raise ValueError("Only gp supported as prediction types.")

    if link_approx not in [la for la in LinkApprox]:
        raise ValueError(f"Unsupported link approximation {link_approx}.")

    if generator is not None:
        if (
            not isinstance(generator, torch.Generator)
            or generator.device != x.device
        ):
            raise ValueError("Invalid random generator (check type and device).")

    likelihood = self.likelihood
    if likelihood == Likelihood.REWARD_MODELING:
        likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION

    return self._glm_forward_call(
        x, likelihood, joint, link_approx, n_samples, diagonal_output
    )

predictive_samples #

predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters:

  • x (Tensor or MutableMapping) –

    input data (batch_size, input_shape)

  • pred_type ('glm', default: 'glm' ) –

    type of posterior predictive, linearized GLM predictive.

  • n_samples (int, default: 100 ) –

    number of samples

  • diagonal_output (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.

  • generator (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)
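
For example, Thompson sampling on top of a fitted instance (la and x_test as in the classification sketch at the top of this page):

# 10 posterior predictive samples over the softmax outputs: (10, batch, n_classes)
samples = la.predictive_samples(x_test, n_samples=10)

# act greedily with respect to a single posterior sample (Thompson sampling)
actions = samples[0].argmax(dim=-1)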

Source code in laplace/baselaplace.py
def predictive_samples(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    pred_type: PredType | str = PredType.GLM,
    n_samples: int = 100,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x`.
    Can be used, for example, for Thompson sampling.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch_size, input_shape)`

    pred_type : {'glm'}, default='glm'
        type of posterior predictive, linearized GLM predictive.

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.
        Only applies when `pred_type='glm'`.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    if pred_type not in PredType.__members__.values():
        raise ValueError("Only glm  supported as prediction type.")

    f_mu, f_var = self._glm_predictive_distribution(x)
    return self._glm_predictive_samples(
        f_mu, f_var, n_samples, diagonal_output, generator
    )

functional_variance #

functional_variance(Js_star: Tensor) -> Tensor

GP posterior variance:

\[ k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*} \]

Parameters:

  • Js_star (torch.Tensor of shape (N*, C, P)) –
      Jacobians of test data points
    

Returns:

  • f_var ( torch.Tensor of shape (N*,C, C) ) –

    Contains the posterior variances of N* testing points.
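
A toy illustration of this formula with random Jacobians, independent of the class internals:

import torch

M, C, P = 6, 2, 3                       # toy subset size, outputs, and parameter count
J_M = torch.randn(M * C, P)             # toy Jacobians of the M subset points, flattened over outputs
J_star = torch.randn(C, P)              # toy Jacobian of a single test point
Lambda_inv = torch.diag(torch.rand(M * C) + 0.1)   # toy L_{MM}^{-1}

K_MM = J_M @ J_M.T
K_star_M = J_star @ J_M.T
k_star_star = J_star @ J_star.T

# k_{**} - K_{*M} (K_{MM} + L_{MM}^{-1})^{-1} K_{M*}
f_var = k_star_star - K_star_M @ torch.linalg.solve(K_MM + Lambda_inv, K_star_M.T)  # (C, C)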

Source code in laplace/baselaplace.py
def functional_variance(self, Js_star: torch.Tensor) -> torch.Tensor:
    """GP posterior variance:

    $$
        k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
    $$

    Parameters
    ----------
    Js_star : torch.Tensor of shape (N*, C, P)
              Jacobians of test data points

    Returns
    -------
    f_var : torch.Tensor of shape (N*,C, C)
            Contains the posterior variances of N* testing points.
    """
    # Compute K_{**}
    K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star)

    # Compute K_{*M}
    K_M_star = []
    for X_batch, _ in self.train_loader:
        K_M_star_batch = self.gp_kernel_prior_variance * self._kernel_batch_star(
            Js_star, X_batch.to(self._device)
        )
        K_M_star.append(K_M_star_batch)
        del X_batch

    # Build_K_star_M computes K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
    f_var = K_star - self._build_K_star_M(K_M_star)

    # If the considered kernel is diagonal, embed the covariances.
    # from (N*, C) -> (N*, C, C)
    if self.independent_outputs:
        f_var = torch.diag_embed(f_var)

    return f_var

functional_covariance #

functional_covariance(Js_star: Tensor) -> Tensor

GP posterior covariance:

\[ k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*} \]

Parameters:

  • Js_star (torch.Tensor of shape (N*, C, P)) –
      Jacobians of test data points
    

Returns:

  • f_var ( torch.Tensor of shape (N*xC, N*xC) ) –

    Contains the posterior covariances of N* testing points.

Source code in laplace/baselaplace.py
def functional_covariance(self, Js_star: torch.Tensor) -> torch.Tensor:
    """GP posterior covariance:

    $$
        k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
    $$

    Parameters
    ----------
    Js_star : torch.Tensor of shape (N*, C, P)
              Jacobians of test data points

    Returns
    -------
    f_var : torch.Tensor of shape (N*xC, N*xC)
            Contains the posterior covariances of N* testing points.
    """
    # Compute K_{**}
    K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star, joint=True)

    # Compute K_{*M}
    K_M_star = []
    for X_batch, _ in self.train_loader:
        K_M_star_batch = self.gp_kernel_prior_variance * self._kernel_batch_star(
            Js_star, X_batch.to(self._device)
        )
        K_M_star.append(K_M_star_batch)
        del X_batch

    # Build_K_star_M computes K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
    f_var = K_star - self._build_K_star_M(K_M_star, joint=True)

    # If the considered kernel is diagonal, embed the covariances.
    # from (N*, N*, C) -> (N*, N*, C, C)
    if self.independent_outputs:
        f_var = torch.diag_embed(f_var)

    # Reshape from (N*, N*, C, C) to (N*xC, N*xC)
    f_var = f_var.permute(0, 2, 1, 3).flatten(0, 1).flatten(1, 2)

    return f_var

_build_K_star_M #

_build_K_star_M(K_M_star: Tensor, joint: bool = False) -> Tensor

Computes K_{*M} (K_{MM} + L_{MM}^{-1})^{-1} K_{M*} given K_{M*}.

Parameters:

  • K_M_star (list of torch.Tensor) –
       Contains K_{M*}. Tensors have shape (N_test, C, C)
       or (N_test, C) for diagonal kernel.
    
  • joint (boolean, default: False ) –
    Whether to compute cross covariances or not.
    

Returns:

  • torch.Tensor –

    of shape (N_test, N_test, C) for joint diagonal, (N_test, C) for non-joint diagonal, (N_test, N_test, C, C) for joint non-diagonal, and (N_test, C, C) for non-joint non-diagonal.
Source code in laplace/baselaplace.py
def _build_K_star_M(
    self, K_M_star: torch.Tensor, joint: bool = False
) -> torch.Tensor:
    """Computes K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*} given K_{M*}.

    Parameters
    ----------
    K_M_star : list of torch.Tensor
               Contains K_{M*}. Tensors have shape (N_test, C, C)
               or (N_test, C) for diagonal kernel.

    joint : boolean
            Whether to compute cross covariances or not.

    Returns
    -------
    torch.tensor of shape (N_test, N_test, C) for joint diagonal,
    (N_test, C) for non-joint diagonal, (N_test, N_test, C, C) for
    joint non-diagonal and (N_test, C, C) for non-joint non-diagonal.
    """
    # Shape (N_test, N, C, C) or (N_test, N, C) for diagonal
    K_M_star = torch.cat(K_M_star, dim=1)

    if self.independent_outputs:
        prods = []
        for c in range(self.n_outputs):
            # Compute K_{*M}L^{-1}
            v = torch.squeeze(
                torch.linalg.solve(
                    self.Sigma_inv[c], K_M_star[:, :, c].unsqueeze(2)
                ),
                2,
            )
            if joint:
                prod = torch.einsum("bm,am->ba", v, v)
            else:
                prod = torch.einsum("bm,bm->b", v, v)
            prods.append(prod.unsqueeze(1))
        prods = torch.cat(prods, dim=-1)
        return prods
    else:
        # Reshape to (N_test, NxC, C) or (N_test, N, C)
        K_M_star = K_M_star.reshape(K_M_star.shape[0], -1, K_M_star.shape[-1])
        # Compute K_{*M}L^{-1}
        v = torch.linalg.solve(self.Sigma_inv, K_M_star)
        if joint:
            return torch.einsum("acm,bcn->abmn", v, v)
        else:
            return torch.einsum("bcm,bcn->bmn", v, v)

optimize_prior_precision #

optimize_prior_precision(pred_type: PredType | str = PredType.GP, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, loss: Metric | Callable[[Tensor], Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) -> None

optimize_prior_precision_base from BaseLaplace with pred_type='gp'
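
A hedged call sketch (la and val_loader as set up elsewhere; 'gridsearch' is assumed here to be an accepted TuningMethod value for validation-based tuning):

la.optimize_prior_precision(
    pred_type="gp",
    method="gridsearch",        # assumed alternative to the default 'marglik'
    val_loader=val_loader,      # held-out loader used for the search
    log_prior_prec_min=-4,
    log_prior_prec_max=4,
    grid_size=50,
)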

Source code in laplace/baselaplace.py
def optimize_prior_precision(
    self,
    pred_type: PredType | str = PredType.GP,
    method: TuningMethod | str = TuningMethod.MARGLIK,
    n_steps: int = 100,
    lr: float = 1e-1,
    init_prior_prec: float | torch.Tensor = 1.0,
    prior_structure: PriorStructure | str = PriorStructure.SCALAR,
    val_loader: DataLoader | None = None,
    loss: torchmetrics.Metric
    | Callable[[torch.Tensor], torch.Tensor | float]
    | None = None,
    log_prior_prec_min: float = -4,
    log_prior_prec_max: float = 4,
    grid_size: int = 100,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    verbose: bool = False,
    progress_bar: bool = False,
) -> None:
    """`optimize_prior_precision_base` from `BaseLaplace` with `pred_type='gp'`"""
    assert pred_type == PredType.GP  # only gp supported
    assert prior_structure == "scalar"  # only isotropic gaussian prior supported
    if method == "marglik":
        warnings.warn(
            "Use of method='marglik' in case of FunctionalLaplace is discouraged, rather use method='CV'."
        )
    super().optimize_prior_precision(
        pred_type,
        method,
        n_steps,
        lr,
        init_prior_prec,
        prior_structure,
        val_loader,
        loss,
        log_prior_prec_min,
        log_prior_prec_max,
        grid_size,
        link_approx,
        n_samples,
        verbose,
        progress_bar,
    )
    self._build_Sigma_inv()

_kernel_batch #

_kernel_batch(jacobians: Tensor, batch: Tensor) -> Tensor

Compute K_bb, which is part of K_MM kernel matrix.

Parameters:

  • jacobians (Tensor(b, C, P)) –
  • batch (Tensor(b, C)) –

Returns:

  • kernel ( tensor ) –

    K_bb with shape (b * C, b * C)

Source code in laplace/baselaplace.py
def _kernel_batch(
    self, jacobians: torch.Tensor, batch: torch.Tensor
) -> torch.Tensor:
    """Compute K_bb, which is part of K_MM kernel matrix.

    Parameters
    ----------
    jacobians : torch.Tensor (b, C, P)
    batch : torch.Tensor (b, C)

    Returns
    -------
    kernel : torch.tensor
        K_bb with shape (b * C, b * C)
    """
    jacobians_2, _ = self._jacobians(batch)
    P = jacobians.shape[-1]  # nr model params
    if self.independent_outputs:
        kernel = torch.empty(
            (jacobians.shape[0], jacobians_2.shape[0], self.n_outputs),
            device=jacobians.device,
        )
        for c in range(self.n_outputs):
            kernel[:, :, c] = torch.einsum(
                "bp,ep->be", jacobians[:, c, :], jacobians_2[:, c, :]
            )
    else:
        kernel = torch.einsum(
            "ap,bp->ab", jacobians.reshape(-1, P), jacobians_2.reshape(-1, P)
        )
    del jacobians_2
    return kernel

_kernel_star #

_kernel_star(jacobians: Tensor, joint: bool = False) -> Tensor

Compute K_star_star kernel matrix.

Parameters:

  • jacobians (Tensor(b, C, P)) –

Returns:

  • kernel ( tensor ) –

    K_star with shape (b, C, C)

Source code in laplace/baselaplace.py
def _kernel_star(
    self, jacobians: torch.Tensor, joint: bool = False
) -> torch.Tensor:
    """Compute K_star_star kernel matrix.

    Parameters
    ----------
    jacobians : torch.Tensor (b, C, P)

    Returns
    -------
    kernel : torch.tensor
        K_star with shape (b, C, C)

    """
    if joint:
        if self.independent_outputs:
            kernel = torch.einsum("acp,bcp->abcc", jacobians, jacobians)
        else:
            kernel = torch.einsum("acp,bep->abce", jacobians, jacobians)

    else:
        if self.independent_outputs:
            kernel = torch.empty(
                (jacobians.shape[0], self.n_outputs), device=jacobians.device
            )
            for c in range(self.n_outputs):
                kernel[:, c] = torch.norm(jacobians[:, c, :], dim=1) ** 2
        else:
            kernel = torch.einsum("bcp,bep->bce", jacobians, jacobians)
    return kernel

_kernel_batch_star #

_kernel_batch_star(jacobians: Tensor, batch: Tensor) -> Tensor

Compute K_b_star, which is a part of K_M_star kernel matrix.

Parameters:

  • jacobians (Tensor(b1, C, P)) –
  • batch (Tensor(b2, C)) –

Returns:

  • kernel ( tensor ) –

    K_batch_star with shape (b1, b2, C, C)

Source code in laplace/baselaplace.py
def _kernel_batch_star(
    self, jacobians: torch.Tensor, batch: torch.Tensor
) -> torch.Tensor:
    """Compute K_b_star, which is a part of K_M_star kernel matrix.

    Parameters
    ----------
    jacobians : torch.Tensor (b1, C, P)
    batch : torch.Tensor (b2, C)

    Returns
    -------
    kernel : torch.tensor
        K_batch_star with shape (b1, b2, C, C)
    """
    jacobians_2, _ = self._jacobians(batch)
    if self.independent_outputs:
        kernel = torch.empty(
            (jacobians.shape[0], jacobians_2.shape[0], self.n_outputs),
            device=jacobians.device,
        )
        for c in range(self.n_outputs):
            kernel[:, :, c] = torch.einsum(
                "bp,ep->be", jacobians[:, c, :], jacobians_2[:, c, :]
            )
    else:
        kernel = torch.einsum("bcp,dep->bdce", jacobians, jacobians_2)
    return kernel

_jacobians #

_jacobians(X: Tensor, enable_backprop: bool = None) -> tuple

A wrapper function to compute Jacobians. This enables reusing the same kernel methods (kernel_batch etc.) in FunctionalLaplace and FunctionalLLLaplace by simply overriding this method instead of all the kernel methods.

Source code in laplace/baselaplace.py
def _jacobians(self, X: torch.Tensor, enable_backprop: bool = None) -> tuple:
    """A wrapper function to compute jacobians - this enables reusing same
    kernel methods (kernel_batch etc.) in FunctionalLaplace and FunctionalLLLaplace
    by simply overwriting this method instead of all kernel methods.
    """
    if enable_backprop is None:
        enable_backprop = self.enable_backprop
    return self.backend.jacobians(X, enable_backprop=enable_backprop)

_mean_scatter_term_batch #

_mean_scatter_term_batch(Js: Tensor, f: Tensor, y: Tensor)

Compute the mean vector in the scatter term of the log marginal likelihood.

See the scatter property above for the exact equations of the mean vectors in the scatter terms for both types of likelihood (regression, classification).

Parameters:

  • Js (tensor) –

    Jacobians (batch, output_shape, parameters)

  • f (tensor) –

    NN output (batch, output_shape)

  • y (Tensor) –

    data labels (batch, output_shape)

Returns:

  • mu ( tensor ) –

    mean vector with shape (batch, output_shape)

Source code in laplace/baselaplace.py
def _mean_scatter_term_batch(
    self, Js: torch.Tensor, f: torch.Tensor, y: torch.Tensor
):
    """Compute mean vector in the scatter term in the log marginal likelihood

    See `scatter_lml` property above for the exact equations of mean vectors in scatter terms for
    both types of likelihood (regression, classification).

    Parameters
    ----------
    Js : torch.tensor
          Jacobians (batch, output_shape, parameters)
    f : torch.tensor
          NN output (batch, output_shape)
    y: torch.tensor
          data labels (batch, output_shape)

    Returns
    -------
    mu : torch.tensor
        mean vector with shape (batch, output_shape)
    """
    if self.likelihood == Likelihood.REGRESSION:
        return y - (f + torch.einsum("bcp,p->bc", Js, self.prior_mean - self.mean))
    elif self.likelihood == Likelihood.CLASSIFICATION:
        return -torch.einsum("bcp,p->bc", Js, self.prior_mean - self.mean)

log_marginal_likelihood #

log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor

Compute the Laplace approximation to the log marginal likelihood. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters:

  • prior_precision (Tensor, default: None ) –

    prior precision if should be changed from current prior_precision value

  • sigma_noise (Tensor, default: None ) –

    observation noise standard deviation if should be changed

Returns:

  • log_marglik ( Tensor ) –
Source code in laplace/baselaplace.py
def log_marginal_likelihood(
    self,
    prior_precision: torch.Tensor | None = None,
    sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the Laplace approximation to the log marginal likelihood.
    Requires that the Laplace approximation has been fit before.
    The resulting torch.Tensor is differentiable in `prior_precision` and
    `sigma_noise` if these have gradients enabled.
    By passing `prior_precision` or `sigma_noise`, the current value is
    overwritten. This is useful for iterating on the log marginal likelihood.

    Parameters
    ----------
    prior_precision : torch.Tensor, optional
        prior precision if should be changed from current `prior_precision` value
    sigma_noise : torch.Tensor, optional
        observation noise standard deviation if should be changed

    Returns
    -------
    log_marglik : torch.Tensor
    """
    # update prior precision (useful when iterating on marglik)
    if prior_precision is not None:
        self.prior_precision = prior_precision

    # update sigma_noise (useful when iterating on marglik)
    if sigma_noise is not None:
        if self.likelihood != Likelihood.REGRESSION:
            raise ValueError("Can only change sigma_noise for regression.")
        self.sigma_noise = sigma_noise

    return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
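
A hedged sketch of iterating on the log marginal likelihood for a regression instance, following the library's general pattern of optimizing the log prior precision and log observation noise (la_reg as in the regression sketch above):

import torch

log_prior = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la_reg.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()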