laplace.baselaplace #

Classes:

  • BaseLaplace

    Base class for all Laplace approximations in this library.

BaseLaplace #

BaseLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)

Base class for all Laplace approximations in this library.

Parameters:

  • model #

    (Module) –
  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation. With 'reward_modeling', Laplace is fitted using the classification likelihood and prediction then follows the regression likelihood. The model must be defined accordingly: during training, the forward pass takes x.shape == (batch_size, 2, dim) with y.shape == (batch_size,), whereas during evaluation x.shape == (batch_size, dim). Note that 'reward_modeling' only supports KronLaplace and DiagLaplace.

  • sigma_noise #

    (Tensor or float, default: 1 ) –

    observation noise for the regression setting; must be 1 for classification

  • prior_precision #

    (Tensor or float, default: 1 ) –

    prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case

  • prior_mean #

    (Tensor or float, default: 0 ) –

    prior mean of a Gaussian prior, useful for continual learning

  • temperature #

    (float, default: 1 ) –

    temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.

  • enable_backprop #

    (bool, default: False ) –

    whether to enable backprop to the input x through the Laplace predictive. Useful for e.g. Bayesian optimization.

  • dict_key_x #

    (str, default: 'input_ids' ) –

    The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • dict_key_y #

    (str, default: 'labels' ) –

    The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • backend #

    (subclasses of `laplace.curvature.CurvatureInterface`, default: None ) –

    backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.

  • backend_kwargs #

    (dict, default: None ) –

    arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

  • asdl_fisher_kwargs #

    (dict, default: None ) –

    arguments passed to the ASDL backend specifically on initialization.

Methods:

Attributes:

  • log_likelihood (Tensor) –

    Compute log likelihood on the training data after .fit() has been called.

  • prior_precision_diag (Tensor) –

    Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Source code in laplace/baselaplace.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    sigma_noise: float | torch.Tensor = 1.0,
    prior_precision: float | torch.Tensor = 1.0,
    prior_mean: float | torch.Tensor = 0.0,
    temperature: float = 1.0,
    enable_backprop: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    backend: type[CurvatureInterface] | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    asdl_fisher_kwargs: dict[str, Any] | None = None,
) -> None:
    if likelihood not in [lik.value for lik in Likelihood]:
        raise ValueError(f"Invalid likelihood type {likelihood}")

    self.model: nn.Module = model
    self.likelihood: Likelihood | str = likelihood

    # Only do Laplace on params that require grad
    self.params: list[torch.Tensor] = []
    self.is_subset_params: bool = False
    for p in model.parameters():
        if p.requires_grad:
            self.params.append(p)
        else:
            self.is_subset_params = True

    self.n_params: int = sum(p.numel() for p in self.params)
    self.n_layers: int = len(self.params)
    self.prior_precision: float | torch.Tensor = prior_precision
    self.prior_mean: float | torch.Tensor = prior_mean
    if sigma_noise != 1 and likelihood != Likelihood.REGRESSION:
        raise ValueError("Sigma noise != 1 only available for regression.")

    self.sigma_noise: float | torch.Tensor = sigma_noise
    self.temperature: float = temperature
    self.enable_backprop: bool = enable_backprop

    # For models with dict-like inputs (e.g. Huggingface LLMs)
    self.dict_key_x = dict_key_x
    self.dict_key_y = dict_key_y

    if backend is None:
        backend = CurvlinopsGGN
    else:
        if self.is_subset_params and (
            "backpack" in backend.__name__.lower()
            or "asdfghjkl" in backend.__name__.lower()
        ):
            raise ValueError(
                "If some grad are switched off, the BackPACK and Asdfghjkl backends"
                " are not supported."
            )

    self._backend: CurvatureInterface | None = None
    self._backend_cls: type[CurvatureInterface] = backend
    self._backend_kwargs: dict[str, Any] = (
        dict() if backend_kwargs is None else backend_kwargs
    )
    self._asdl_fisher_kwargs: dict[str, Any] = (
        dict() if asdl_fisher_kwargs is None else asdl_fisher_kwargs
    )

    # log likelihood = g(loss)
    self.loss: float = 0.0
    self.n_outputs: int = 0
    self.n_data: int = 0

    # Declare attributes
    self._prior_mean: torch.Tensor
    self._prior_precision: torch.Tensor
    self._sigma_noise: torch.Tensor
    self._posterior_scale: torch.Tensor | None

log_likelihood #

log_likelihood: Tensor

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on demand from the loss and, for example, the observation noise, which keeps it differentiable with respect to the observation noise for iterative updates.

Returns:

  • log_likelihood ( Tensor ) –
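
As an illustration of how a log likelihood can be recovered on demand from a stored loss, the sketch below treats the regression case under an i.i.d. Gaussian likelihood. It assumes the stored loss is half the sum of squared errors; the function name `gaussian_log_likelihood` and that loss convention are assumptions made for this example and are not taken from the library.

```python
import math


def gaussian_log_likelihood(half_sse, n_data, n_outputs, sigma_noise):
    """Gaussian log likelihood given 0.5 * (sum of squared errors).

    log p(y | f) = -half_sse / sigma^2 - N * K * log(sigma * sqrt(2 * pi))
    """
    normalizer = n_data * n_outputs * math.log(sigma_noise * math.sqrt(2 * math.pi))
    return -half_sse / sigma_noise**2 - normalizer


# Toy data (hypothetical): three scalar targets and predictions.
targets = [1.0, 2.0, 3.0]
preds = [0.5, 2.5, 2.0]
half_sse = 0.5 * sum((y - f) ** 2 for y, f in zip(targets, preds))

# The same stored loss can be re-evaluated for any observation noise.
log_lik = gaussian_log_likelihood(half_sse, n_data=3, n_outputs=1, sigma_noise=0.7)
```

Because the loss is stored once and the noise enters only through this closed-form expression, changing `sigma_noise` does not require another pass over the data.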

prior_precision_diag #

prior_precision_diag: Tensor

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

  • prior_precision_diag ( Tensor ) –
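
The three accepted prior precision shapes can be pictured as different ways of filling one vector with an entry per parameter. The following is a minimal, framework-free sketch under that reading; `expand_prior_precision` and `layer_sizes` are hypothetical names for this example, and a "layer" is taken to mean one parameter tensor, in line with `n_layers = len(self.params)` in the constructor.

```python
def expand_prior_precision(prior_precision, layer_sizes):
    """Expand a scalar, layer-wise, or diagonal prior precision to one value per parameter.

    prior_precision: list of floats with length 1, len(layer_sizes), or sum(layer_sizes).
    layer_sizes: number of parameters in each parameter tensor.
    """
    n_layers = len(layer_sizes)
    n_params = sum(layer_sizes)
    if len(prior_precision) == 1:
        # Scalar: the same precision for every parameter.
        return [prior_precision[0]] * n_params
    if len(prior_precision) == n_params:
        # Diagonal: already one precision per parameter.
        return list(prior_precision)
    if len(prior_precision) == n_layers:
        # Layer-wise: repeat each precision across its parameter tensor.
        return [p for p, size in zip(prior_precision, layer_sizes) for _ in range(size)]
    raise ValueError("prior precision must be scalar, layer-wise, or diagonal")


scalar = expand_prior_precision([2.0], layer_sizes=[2, 3])
layerwise = expand_prior_precision([1.0, 10.0], layer_sizes=[2, 3])
```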

optimize_prior_precision #

optimize_prior_precision(pred_type: PredType | str, method: TuningMethod | str = MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | Tensor = 1.0, prior_structure: PriorStructure | str = DIAG, val_loader: DataLoader | None = None, loss: Metric | Callable[[Tensor], Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) -> None

Optimize the prior precision post-hoc using the method specified by the user.

Parameters:

  • pred_type #

    (PredType or str in {'glm', 'nn'}) –

    type of posterior predictive: the linearized GLM predictive or the neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.

  • method #

    (TuningMethod or str in {'marglik', 'gridsearch'}, default: TuningMethod.MARGLIK ) –

    specifies how the prior precision should be optimized.

  • n_steps #

    (int, default: 100 ) –

    the number of gradient descent steps to take.

  • lr #

    (float, default: 1e-1 ) –

    the learning rate to use for gradient descent.

  • init_prior_prec #

    (float or tensor, default: 1.0 ) –

    initial prior precision before the first optimization step.

  • prior_structure #

    (PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default: PriorStructure.DIAG ) –

    if init_prior_prec is scalar, the prior precision is optimized with this structure. Otherwise, the structure of init_prior_prec is maintained.

  • val_loader #

    (DataLoader, default: None ) –

    DataLoader for the validation set; each iterate is a batch (X, y).

  • loss #

    (callable or Metric, default: None ) –

    loss function to use for CV. If callable, the loss is computed offline (memory intensive). If torchmetrics.Metric, running loss is computed (efficient). The default depends on the likelihood: RunningNLLMetric() for classification and reward modeling, running MeanSquaredError() for regression.

  • log_prior_prec_min #

    (float, default: -4 ) –

    lower bound of gridsearch interval.

  • log_prior_prec_max #

    (float, default: 4 ) –

    upper bound of gridsearch interval.

  • grid_size #

    (int, default: 100 ) –

    number of values to consider inside the gridsearch interval.

  • link_approx #

    (LinkApprox or str in {'mc', 'probit', 'bridge'}, default: LinkApprox.PROBIT ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • verbose #

    (bool, default: False ) –

    if true, the optimized prior precision will be printed (can be a large tensor if the prior has a diagonal covariance).

  • progress_bar #

    (bool, default: False ) –

    whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when subset_of_weights='all'.
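
To make the grid search parameters concrete, here is a small sketch of how `log_prior_prec_min`, `log_prior_prec_max`, and `grid_size` could define the candidate set, reading the bounds as base-10 exponents as `torch.logspace` does. The helpers `log_grid`, `grid_search`, and the toy loss are hypothetical stand-ins; a real run would score each candidate on the validation loader.

```python
import math


def log_grid(log_min, log_max, grid_size):
    """Return grid_size values spaced evenly in log10 from 10**log_min to 10**log_max."""
    if grid_size == 1:
        return [10.0**log_min]
    step = (log_max - log_min) / (grid_size - 1)
    return [10.0 ** (log_min + i * step) for i in range(grid_size)]


def grid_search(candidates, validation_loss):
    """Pick the candidate prior precision with the lowest validation loss."""
    return min(candidates, key=validation_loss)


def toy_validation_loss(prior_prec):
    # Hypothetical stand-in for a validation metric; smallest at prior_prec = 1.
    return math.log10(prior_prec) ** 2


grid = log_grid(-4, 4, 9)  # 1e-4, 1e-3, ..., 1e4
best = grid_search(grid, toy_validation_loss)
```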

Source code in laplace/baselaplace.py
def optimize_prior_precision(
    self,
    pred_type: PredType | str,
    method: TuningMethod | str = TuningMethod.MARGLIK,
    n_steps: int = 100,
    lr: float = 1e-1,
    init_prior_prec: float | torch.Tensor = 1.0,
    prior_structure: PriorStructure | str = PriorStructure.DIAG,
    val_loader: DataLoader | None = None,
    loss: torchmetrics.Metric
    | Callable[[torch.Tensor], torch.Tensor | float]
    | None = None,
    log_prior_prec_min: float = -4,
    log_prior_prec_max: float = 4,
    grid_size: int = 100,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    verbose: bool = False,
    progress_bar: bool = False,
) -> None:
    """Optimize the prior precision post-hoc using the `method`
    specified by the user.

    Parameters
    ----------
    pred_type : PredType or str in {'glm', 'nn'}
        type of posterior predictive, linearized GLM predictive or neural
        network sampling predictive. The GLM predictive is consistent with the
        curvature approximations used here.
    method : TuningMethod or str in {'marglik', 'gridsearch'}, default=TuningMethod.MARGLIK
        specifies how the prior precision should be optimized.
    n_steps : int, default=100
        the number of gradient descent steps to take.
    lr : float, default=1e-1
        the learning rate to use for gradient descent.
    init_prior_prec : float or tensor, default=1.0
        initial prior precision before the first optimization step.
    prior_structure : PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default=PriorStructure.DIAG
        if init_prior_prec is scalar, the prior precision is optimized with this structure.
        otherwise, the structure of init_prior_prec is maintained.
    val_loader : torch.utils.data.DataLoader, default=None
        DataLoader for the validation set; each iterate is a training batch (X, y).
    loss : callable or torchmetrics.Metric, default=None
        loss function to use for CV. If callable, the loss is computed offline (memory intensive).
        If torchmetrics.Metric, running loss is computed (efficient). The default
        depends on the likelihood: `RunningNLLMetric()` for classification and
        reward modeling, running `MeanSquaredError()` for regression.
    log_prior_prec_min : float, default=-4
        lower bound of gridsearch interval.
    log_prior_prec_max : float, default=4
        upper bound of gridsearch interval.
    grid_size : int, default=100
        number of values to consider inside the gridsearch interval.
    link_approx : LinkApprox or str in {'mc', 'probit', 'bridge'}, default=LinkApprox.PROBIT
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only `'mc'` is possible.
    n_samples : int, default=100
        number of samples for `link_approx='mc'`.
    verbose : bool, default=False
        if true, the optimized prior precision will be printed
        (can be a large tensor if the prior has a diagonal covariance).
    progress_bar : bool, default=False
        whether to show a progress bar; updated at every batch-Hessian computation.
        Useful for very large models and large amounts of data, esp. when `subset_of_weights='all'`.
    """
    likelihood = (
        Likelihood.CLASSIFICATION
        if self.likelihood == Likelihood.REWARD_MODELING
        else self.likelihood
    )

    if likelihood == Likelihood.CLASSIFICATION:
        warnings.warn(
            "By default `link_approx` is `probit`. Make sure to set it equals to "
            "the way you want to call `la(test_data, pred_type=..., link_approx=...)`."
        )

    if method == TuningMethod.MARGLIK:
        if val_loader is not None:
            warnings.warn(
                "`val_loader` will be ignored when `method` == 'marglik'. "
                "Do you mean to set `method = 'gridsearch'`?"
            )

        self.prior_precision = (
            init_prior_prec
            if isinstance(init_prior_prec, torch.Tensor)
            else torch.as_tensor(init_prior_prec)
        )

        if (
            len(self.prior_precision) == 1
            and prior_structure != PriorStructure.SCALAR
        ):
            self.prior_precision = fix_prior_prec_structure(
                self.prior_precision.item(),
                prior_structure,
                self.n_layers,
                self.n_params,
                self._device,
                self._dtype,
            )

        log_prior_prec = self.prior_precision.log()
        log_prior_prec.requires_grad = True
        optimizer = torch.optim.Adam([log_prior_prec], lr=lr)

        if progress_bar:
            pbar = tqdm.trange(n_steps)
            pbar.set_description("[Optimizing marginal likelihood]")
        else:
            pbar = range(n_steps)

        for _ in pbar:
            optimizer.zero_grad()
            prior_prec = log_prior_prec.exp()
            neg_log_marglik = -self.log_marginal_likelihood(
                prior_precision=prior_prec
            )
            neg_log_marglik.backward()
            optimizer.step()

        self.prior_precision = log_prior_prec.detach().exp()
    elif method == TuningMethod.GRIDSEARCH:
        if val_loader is None:
            raise ValueError("gridsearch requires a validation set DataLoader")

        interval = torch.logspace(log_prior_prec_min, log_prior_prec_max, grid_size)

        if loss is None:
            loss = (
                torchmetrics.MeanSquaredError(num_outputs=self.n_outputs).to(
                    self._device
                )
                if likelihood == Likelihood.REGRESSION
                else RunningNLLMetric().to(self._device)
            )

        self.prior_precision = self._gridsearch(
            loss,
            interval,
            val_loader,
            pred_type=pred_type,
            link_approx=link_approx,
            n_samples=n_samples,
            progress_bar=progress_bar,
        )
    else:
        raise ValueError("For now only marglik and gridsearch is implemented.")

    if verbose:
        print(f"Optimized prior precision is {self.prior_precision}.")
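
The 'marglik' branch above optimizes the logarithm of the prior precision, which keeps the precision positive without any constraint. The sketch below shows that idea on a toy problem: a scalar prior precision, a fixed MAP estimate, and a diagonal likelihood Hessian. It uses plain gradient ascent rather than Adam, and all names and numbers (`log_marglik_grad`, `theta_map`, `lik_hessian_diag`) are assumptions for illustration only.

```python
import math


def log_marglik_grad(prior_prec, theta_map, lik_hessian_diag):
    """d/dp of the Laplace log marginal likelihood for a scalar prior precision p.

    Terms that depend on p:
        -0.5 * p * ||theta||^2 + 0.5 * d * log(p) - 0.5 * sum(log(h_i + p))
    """
    d = len(theta_map)
    sq_norm = sum(t * t for t in theta_map)
    return (
        -0.5 * sq_norm
        + 0.5 * d / prior_prec
        - 0.5 * sum(1.0 / (h + prior_prec) for h in lik_hessian_diag)
    )


theta_map = [0.5, -1.0, 0.25]       # toy MAP parameters (assumed)
lik_hessian_diag = [4.0, 1.0, 9.0]  # toy diagonal curvature (assumed)

log_prior_prec = math.log(1.0)  # start from a prior precision of 1.0
lr = 0.1
for _ in range(500):
    prior_prec = math.exp(log_prior_prec)
    # Chain rule: d/d(log p) = p * d/dp. Ascend the log marginal likelihood.
    grad_wrt_p = log_marglik_grad(prior_prec, theta_map, lik_hessian_diag)
    log_prior_prec += lr * prior_prec * grad_wrt_p

prior_prec = math.exp(log_prior_prec)  # always positive by construction
```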

_glm_forward_call #

_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]

Compute the posterior predictive on input data x for "glm" pred type.

Parameters:

  • x #

    (Tensor or MutableMapping) –

    (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.

  • likelihood #

    (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) –

    determines the log likelihood Hessian approximation.

  • link_approx #

    (LinkApprox or str in {'mc', 'probit', 'bridge', 'bridge_norm'}, default: LinkApprox.PROBIT ) –

    how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.

  • joint #

    (bool, default: False ) –

    Whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). If False, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive.

  • n_samples #

    (int, default: 100 ) –

    number of samples for link_approx='mc'.

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.

Returns:

  • predictive ( Tensor or tuple[Tensor] ) –

    For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a Softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
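
For the 'probit' link approximation, each logit is scaled by kappa = 1 / sqrt(1 + pi/8 * variance) before the softmax, so uncertain logits are pulled toward zero. A minimal pure-Python sketch for a single input follows; `probit_predictive` and the example numbers are illustrative assumptions, and only the diagonal of the output covariance is used.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probit_predictive(f_mu, f_var_diag):
    """Scale each logit by 1 / sqrt(1 + pi/8 * var), then apply softmax."""
    scaled = [
        mu / math.sqrt(1.0 + math.pi / 8.0 * var)
        for mu, var in zip(f_mu, f_var_diag)
    ]
    return softmax(scaled)


# Same mean logits, with zero and with large predictive variance.
confident = probit_predictive([2.0, 0.0, -1.0], [0.0, 0.0, 0.0])
uncertain = probit_predictive([2.0, 0.0, -1.0], [25.0, 25.0, 25.0])
```

With zero variance the result reduces to the ordinary softmax of the mean; with large variance the class probabilities move toward uniform.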

Source code in laplace/baselaplace.py
def _glm_forward_call(
    self,
    x: torch.Tensor | MutableMapping,
    likelihood: Likelihood | str,
    joint: bool = False,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the posterior predictive on input data `x` for "glm" pred type.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        `(batch_size, input_shape)` if tensor. If MutableMapping, must contain
        the said tensor.

    likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
        determines the log likelihood Hessian approximation.

    link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
        how to approximate the classification link function for the `'glm'`.
        For `pred_type='nn'`, only 'mc' is possible.

    joint : bool
        Whether to output a joint predictive distribution in regression with
        `pred_type='glm'`. If set to `True`, the predictive distribution
        has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        If `False`, then only outputs the marginal predictive distribution.
        Only available for regression and GLM predictive.

    n_samples : int
        number of samples for `link_approx='mc'`.

    diagonal_output : bool
        whether to use a diagonalized posterior predictive on the outputs.
        Only works for `pred_type='glm'` and `link_approx='mc'`.

    Returns
    -------
    predictive: torch.Tensor or tuple[torch.Tensor]
        For `likelihood='classification'`, a torch.Tensor is returned with
        a distribution over classes (similar to a Softmax).
        For `likelihood='regression'`, a tuple of torch.Tensor is returned
        with the mean and the predictive variance.
        For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
        is returned with the mean and the predictive covariance.
    """
    f_mu, f_var = self._glm_predictive_distribution(
        x, joint=joint and likelihood == Likelihood.REGRESSION
    )

    if likelihood == Likelihood.REGRESSION:
        if diagonal_output and not joint:
            f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
        return f_mu, f_var

    if link_approx == LinkApprox.MC:
        return self._glm_predictive_samples(
            f_mu,
            f_var,
            n_samples=n_samples,
            diagonal_output=diagonal_output,
        ).mean(dim=0)
    elif link_approx == LinkApprox.PROBIT:
        kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
        return torch.softmax(kappa * f_mu, dim=-1)
    elif "bridge" in link_approx:
        # zero mean correction
        f_mu -= (
            f_var.sum(-1)
            * f_mu.sum(-1).reshape(-1, 1)
            / f_var.sum(dim=(1, 2)).reshape(-1, 1)
        )
        f_var -= torch.einsum(
            "bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
        ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)

        # Laplace Bridge
        _, K = f_mu.size(0), f_mu.size(-1)
        f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)

        # optional: variance correction
        if link_approx == LinkApprox.BRIDGE_NORM:
            f_var_diag_mean = f_var_diag.mean(dim=1)
            f_var_diag_mean /= torch.as_tensor(
                [K / 2], device=self._device, dtype=self._dtype
            ).sqrt()
            f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
            f_var_diag /= f_var_diag_mean.unsqueeze(-1)

        sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
        alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
        return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
    else:
        raise ValueError(
            "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
        )
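
The 'bridge' branch maps the Gaussian over logits to Dirichlet concentration parameters and returns the Dirichlet mean. The sketch below reproduces only that final step for one input in pure Python; it skips the zero-mean correction and the optional 'bridge_norm' rescaling, and `laplace_bridge_mean` is a hypothetical name for this example.

```python
import math


def laplace_bridge_mean(f_mu, f_var_diag):
    """Dirichlet mean from logit means and variances via the Laplace bridge.

    alpha_k = (1 - 2/K + exp(mu_k) / K^2 * sum_j exp(-mu_j)) / var_k
    """
    K = len(f_mu)
    sum_exp_neg = sum(math.exp(-m) for m in f_mu)
    alpha = [
        (1.0 - 2.0 / K + math.exp(m) / K**2 * sum_exp_neg) / v
        for m, v in zip(f_mu, f_var_diag)
    ]
    total = sum(alpha)
    return [a / total for a in alpha]


uniform = laplace_bridge_mean([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
peaked = laplace_bridge_mean([1.0, 0.0, -1.0], [1.0, 1.0, 1.0])
```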

_glm_functional_samples #

_glm_functional_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior functional on input data x using "glm" prediction type.

Parameters:

  • f_mu #

    (Tensor) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_functional_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior functional on input data `x` using "glm" prediction
    type.

    Parameters
    ----------
    f_mu : torch.Tensor
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])

    if diagonal_output:
        f_var = torch.diagonal(f_var, dim1=1, dim2=2)

    return normal_samples(f_mu, f_var, n_samples, generator)
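
For the `diagonal_output=True` case, drawing functional samples amounts to adding scaled standard-normal noise to each output mean. The following is a pure-Python sketch of that reparameterization with nested lists standing in for tensors; `diag_normal_samples` and its `seed` argument are hypothetical and play the role that `normal_samples` and `generator` play above.

```python
import math
import random


def diag_normal_samples(f_mu, f_var_diag, n_samples, seed=None):
    """Draw samples of shape (n_samples, batch_size, n_outputs) as nested lists.

    Each entry is mu + sqrt(var) * eps with eps ~ N(0, 1).
    """
    rng = random.Random(seed)  # seeded generator for reproducible draws
    return [
        [
            [mu + math.sqrt(var) * rng.gauss(0.0, 1.0) for mu, var in zip(mu_row, var_row)]
            for mu_row, var_row in zip(f_mu, f_var_diag)
        ]
        for _ in range(n_samples)
    ]


f_mu = [[0.0, 1.0], [2.0, -1.0]]        # (batch_size=2, n_outputs=2)
f_var_diag = [[0.25, 0.25], [0.0, 0.0]]  # second input has zero variance
samples = diag_normal_samples(f_mu, f_var_diag, n_samples=4000, seed=0)
```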

_glm_predictive_samples #

_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor

Sample from the posterior predictive on input data x using "glm" prediction type. I.e., the inverse-link function corresponding to the likelihood is applied on top of the functional sample.

Parameters:

  • f_mu #

    (Tensor) –

    glm predictive mean (batch_size, output_shape)

  • f_var #

    (Tensor) –

    glm predictive covariances (batch_size, output_shape, output_shape)

  • n_samples #

    (int) –

    number of samples

  • diagonal_output #

    (bool, default: False ) –

    whether to use a diagonalized glm posterior predictive on the outputs.

  • generator #

    (Generator, default: None ) –

    random number generator to control the samples (if sampling used)

Returns:

  • samples ( Tensor ) –

    samples (n_samples, batch_size, output_shape)

Source code in laplace/baselaplace.py
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample from the posterior predictive on input data `x` using "glm" prediction
    type. I.e., the inverse-link function corresponding to the likelihood is applied
    on top of the functional sample.

    Parameters
    ----------
    f_mu : torch.Tensor
        glm predictive mean `(batch_size, output_shape)`

    f_var : torch.Tensor
        glm predictive covariances `(batch_size, output_shape, output_shape)`

    n_samples : int
        number of samples

    diagonal_output : bool
        whether to use a diagonalized glm posterior predictive on the outputs.

    generator : torch.Generator, optional
        random number generator to control the samples (if sampling used)

    Returns
    -------
    samples : torch.Tensor
        samples `(n_samples, batch_size, output_shape)`
    """
    f_samples = self._glm_functional_samples(
        f_mu, f_var, n_samples, diagonal_output, generator
    )

    if self.likelihood == Likelihood.REGRESSION:
        return f_samples
    else:
        return torch.softmax(f_samples, dim=-1)
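
In the classification case, each functional sample is pushed through the softmax, and the 'mc' link approximation then averages those per-sample probabilities. A self-contained sketch for a single input with diagonal output variance is given below; `mc_class_probs` and the example values are assumptions for illustration.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mc_class_probs(f_mu, f_var_diag, n_samples, seed=0):
    """Average softmax(f) over Gaussian samples f ~ N(f_mu, diag(f_var_diag))."""
    rng = random.Random(seed)
    mean_probs = [0.0] * len(f_mu)
    for _ in range(n_samples):
        # Functional sample, then the inverse link (softmax) on top of it.
        f_sample = [mu + math.sqrt(var) * rng.gauss(0.0, 1.0) for mu, var in zip(f_mu, f_var_diag)]
        probs = softmax(f_sample)
        mean_probs = [acc + p / n_samples for acc, p in zip(mean_probs, probs)]
    return mean_probs


f_mu = [2.0, 0.0, -1.0]
certain = mc_class_probs(f_mu, [0.0, 0.0, 0.0], n_samples=10)
uncertain = mc_class_probs(f_mu, [9.0, 9.0, 9.0], n_samples=2000)
```

Averaging after the softmax, rather than applying the softmax to the mean logits, is what lets the predictive variance soften the class probabilities.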