Skip to content

laplace.curvature #

CurvatureInterface #

CurvatureInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Interface to access curvature for a model and corresponding likelihood. A curvature interface must inherit from this base class and implement the necessary functions jacobians, full, kron, and diag. The interface might be extended in the future to account for other curvature structures, for example, a block-diagonal one.

Parameters:

  • model (torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`) –

    torch model (neural network)

  • likelihood ('classification' or 'regression') –

    likelihood of the model; this argument is required and has no default.
  • last_layer (bool, default: False ) –

    only consider curvature of last layer

  • subnetwork_indices (LongTensor, default: None ) –

    indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

  • dict_key_x (str, default: 'input_ids' ) –

    The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • dict_key_y (str, default: 'labels' ) –

    The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

Attributes:

  • lossfunc (MSELoss or CrossEntropyLoss) –
  • factor (float) –

    conversion factor between torch losses and base likelihoods. For example, \(\frac{1}{2}\) to get to \(\mathcal{N}(f, 1)\) from MSELoss.

Source code in laplace/curvature/curvature.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
):
    """Store the model/likelihood configuration and collect the trainable tensors.

    Parameters
    ----------
    model : torch.nn.Module
        the network whose curvature is requested.
    likelihood : Likelihood or str
        must be `Likelihood.REGRESSION` or `Likelihood.CLASSIFICATION`
        (enforced with an `assert`).
    last_layer : bool, default=False
        only consider curvature of the last layer.
    subnetwork_indices : torch.LongTensor or None, default=None
        indices of the vectorized parameters that define the subnetwork.
    dict_key_x, dict_key_y : str
        keys of the input and target tensors when the model consumes a
        `MutableMapping`.
    """
    assert likelihood in [Likelihood.REGRESSION, Likelihood.CLASSIFICATION]

    self.likelihood: Likelihood | str = likelihood
    self.model: nn.Module = model
    self.last_layer: bool = last_layer
    self.subnetwork_indices: torch.LongTensor | None = subnetwork_indices
    self.dict_key_x = dict_key_x
    self.dict_key_y = dict_key_y

    # Summed torch losses; `factor` converts them to the base likelihood
    # (1/2 turns the summed MSE into N(f, 1)).
    is_regression = likelihood == "regression"
    self.lossfunc: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (
        MSELoss(reduction="sum")
        if is_regression
        else CrossEntropyLoss(reduction="sum")
    )
    self.factor: float = 0.5 if is_regression else 1.0

    # Trainable parameters are read from `self._model`, which is not defined in
    # this method (presumably a property elsewhere in the class — confirm).
    trainable = [
        (name, param)
        for name, param in self._model.named_parameters()
        if param.requires_grad
    ]
    self.params: list[nn.Parameter] = [param for _, param in trainable]
    self.params_dict: dict[str, nn.Parameter] = dict(trainable)
    # Buffers are taken from the full `model`, not `_model`.
    self.buffers_dict: dict[str, torch.Tensor] = dict(self.model.named_buffers())

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a dense curvature (approximation) in the form of a \(P \times P\) matrix \(H\) with respect to parameters \(\theta \in \mathbb{R}^P\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    Hessian approximation (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix
    \\(H\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        Hessian approximation `(parameters, parameters)`
    """
    raise NotImplementedError

kron #

kron(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, N: int, **kwargs: dict[str, Any]) -> tuple[Tensor, Kron]

Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to \(H\) takes the form of two Kronecker factors \(Q\) and \(R\), i.e., \(H \approx Q \otimes R\) for each Module in the neural network permitting such curvature. \(Q\) is quadratic in the input-dimension of a module, \(p_{in} \times p_{in}\), and \(R\) in the output-dimension, \(p_{out} \times p_{out}\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

  • N (int) –

    total number of data points

Returns:

  • loss ( Tensor ) –
  • H ( `laplace.utils.matrix.Kron` ) –

    Kronecker factored Hessian approximation.

Source code in laplace/curvature/curvature.py
def kron(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    N: int,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, Kron]:
    """Compute a Kronecker factored curvature approximation (such as KFAC).
    The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q\\) and \\(R\\),
    i.e., \\(H \\approx Q \\otimes R\\) for each Module in the neural network permitting
    such curvature.
    \\(Q\\) is quadratic in the input-dimension of a module, \\(p_{in} \\times p_{in}\\),
    and \\(R\\) in the output-dimension, \\(p_{out} \\times p_{out}\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`
    N : int
        total number of data points

    Returns
    -------
    loss : torch.Tensor
    H : `laplace.utils.matrix.Kron`
        Kronecker factored Hessian approximation.

    Raises
    ------
    NotImplementedError
        always; concrete curvature interfaces must override this method.
    """
    raise NotImplementedError

diag #

diag(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a diagonal approximation to the Hessian \(H\), represented as a vector with the dimensionality of the parameters \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    vector representing the diagonal of H

Source code in laplace/curvature/curvature.py
def diag(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
    vector of the dimensionality of parameters \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        vector representing the diagonal of H
    """
    raise NotImplementedError

GGNInterface #

GGNInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False, num_samples: int = 1)

Bases: CurvatureInterface

Generalized Gauss-Newton or Fisher Curvature Interface. The GGN is equal to the Fisher information for the available likelihoods. In addition to CurvatureInterface, methods for Jacobians are required by subclasses.

Parameters:

  • model (torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`) –

    torch model (neural network)

  • likelihood ('classification' or 'regression') –

    likelihood of the model; this argument is required and has no default.
  • last_layer (bool, default: False ) –

    only consider curvature of last layer

  • subnetwork_indices (Tensor, default: None ) –

    indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

  • dict_key_x (str, default: 'input_ids' ) –

    The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • dict_key_y (str, default: 'labels' ) –

    The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • stochastic (bool, default: False ) –

    Fisher if stochastic else GGN

  • num_samples (int, default: 1 ) –

    Number of samples used to approximate the stochastic Fisher

Source code in laplace/curvature/curvature.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    stochastic: bool = False,
    num_samples: int = 1,
) -> None:
    """Configure a GGN / Fisher curvature interface.

    Parameters
    ----------
    model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
        forwarded unchanged to the base-class constructor.
    stochastic : bool, default=False
        Monte-Carlo Fisher if True, GGN otherwise.
    num_samples : int, default=1
        number of samples used to approximate the stochastic Fisher.
    """
    # GGN-specific options are stored before delegating the shared setup.
    self.stochastic: bool = stochastic
    self.num_samples: int = num_samples

    super().__init__(
        model,
        likelihood,
        last_layer=last_layer,
        subnetwork_indices=subnetwork_indices,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

kron #

kron(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, N: int, **kwargs: dict[str, Any]) -> tuple[Tensor, Kron]

Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to \(H\) takes the form of two Kronecker factors \(Q\) and \(R\), i.e., \(H \approx Q \otimes R\) for each Module in the neural network permitting such curvature. \(Q\) is quadratic in the input-dimension of a module, \(p_{in} \times p_{in}\), and \(R\) in the output-dimension, \(p_{out} \times p_{out}\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

  • N (int) –

    total number of data points

Returns:

  • loss ( Tensor ) –
  • H ( `laplace.utils.matrix.Kron` ) –

    Kronecker factored Hessian approximation.

Source code in laplace/curvature/curvature.py
def kron(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    N: int,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, Kron]:
    """Compute a Kronecker factored curvature approximation (such as KFAC).
    The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q\\) and \\(R\\),
    i.e., \\(H \\approx Q \\otimes R\\) for each Module in the neural network permitting
    such curvature.
    \\(Q\\) is quadratic in the input-dimension of a module, \\(p_{in} \\times p_{in}\\),
    and \\(R\\) in the output-dimension, \\(p_{out} \\times p_{out}\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`
    N : int
        total number of data points

    Returns
    -------
    loss : torch.Tensor
    H : `laplace.utils.matrix.Kron`
        Kronecker factored Hessian approximation.

    Raises
    ------
    NotImplementedError
        always; concrete curvature interfaces must override this method.
    """
    raise NotImplementedError

_get_mc_functional_fisher #

_get_mc_functional_fisher(f: Tensor) -> Tensor

Approximate the Fisher's middle matrix (expected outer product of the functional gradient) using MC integral with self.num_samples many samples.

Source code in laplace/curvature/curvature.py
def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:
    """Approximate the Fisher's middle matrix (expected outer product of the functional gradient)
    using MC integral with `self.num_samples` many samples.
    """
    F = 0

    for _ in range(self.num_samples):
        if self.likelihood == "regression":
            y_sample = f + torch.randn(f.shape, device=f.device)  # N(y | f, 1)
            grad_sample = f - y_sample  # functional MSE grad
        else:  # classification with softmax
            y_sample = torch.distributions.Multinomial(logits=f).sample()
            # First functional derivative of the loglik is p - y
            p = torch.softmax(f, dim=-1)
            grad_sample = p - y_sample

        F += (
            1
            / self.num_samples
            * torch.einsum("bc,bk->bck", grad_sample, grad_sample)
        )

    return F

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full GGN \(P \times P\) matrix as Hessian approximation \(H_{ggn}\) with respect to parameters \(\theta \in \mathbb{R}^P\). For the last-layer approximation, \(\theta\) is reduced to \(\theta_{last}\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    GGN (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation
    \\(H_{ggn}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
    For last-layer, reduced to \\(\\theta_{last}\\)

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        GGN `(parameters, parameters)`
    """
    Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
    H_lik = (
        self._get_mc_functional_fisher(f)
        if self.stochastic
        else self._get_functional_hessian(f)
    )

    if H_lik is not None:
        H = torch.einsum("bcp,bck,bkq->pq", Js, H_lik, Js)
    else:  # The case of exact GGN for regression
        H = torch.einsum("bcp,bcq->pq", Js, Js)
    loss = self.factor * self.lossfunc(f, y)

    return loss.detach(), H.detach()

EFInterface #

EFInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvatureInterface

Interface for Empirical Fisher as Hessian approximation. In addition to CurvatureInterface, methods for gradients are required by subclasses.

Parameters:

  • model (torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`) –

    torch model (neural network)

  • likelihood ('classification' or 'regression') –

    likelihood of the model; this argument is required and has no default.
  • last_layer (bool, default: False ) –

    only consider curvature of last layer

  • subnetwork_indices (Tensor, default: None ) –

    indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over

  • dict_key_x (str, default: 'input_ids' ) –

    The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

  • dict_key_y (str, default: 'labels' ) –

    The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.

Attributes:

  • lossfunc (MSELoss or CrossEntropyLoss) –
  • factor (float) –

    conversion factor between torch losses and base likelihoods. For example, \(\frac{1}{2}\) to get to \(\mathcal{N}(f, 1)\) from MSELoss.

Source code in laplace/curvature/curvature.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
):
    """Store the model/likelihood configuration and collect the trainable tensors.

    Parameters
    ----------
    model : torch.nn.Module
        the network whose curvature is requested.
    likelihood : Likelihood or str
        must be `Likelihood.REGRESSION` or `Likelihood.CLASSIFICATION`
        (enforced with an `assert`).
    last_layer : bool, default=False
        only consider curvature of the last layer.
    subnetwork_indices : torch.LongTensor or None, default=None
        indices of the vectorized parameters that define the subnetwork.
    dict_key_x, dict_key_y : str
        keys of the input and target tensors when the model consumes a
        `MutableMapping`.
    """
    assert likelihood in [Likelihood.REGRESSION, Likelihood.CLASSIFICATION]

    self.likelihood: Likelihood | str = likelihood
    self.model: nn.Module = model
    self.last_layer: bool = last_layer
    self.subnetwork_indices: torch.LongTensor | None = subnetwork_indices
    self.dict_key_x = dict_key_x
    self.dict_key_y = dict_key_y

    # Summed torch losses; `factor` converts them to the base likelihood
    # (1/2 turns the summed MSE into N(f, 1)).
    is_regression = likelihood == "regression"
    self.lossfunc: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (
        MSELoss(reduction="sum")
        if is_regression
        else CrossEntropyLoss(reduction="sum")
    )
    self.factor: float = 0.5 if is_regression else 1.0

    # Trainable parameters are read from `self._model`, which is not defined in
    # this method (presumably a property elsewhere in the class — confirm).
    trainable = [
        (name, param)
        for name, param in self._model.named_parameters()
        if param.requires_grad
    ]
    self.params: list[nn.Parameter] = [param for _, param in trainable]
    self.params_dict: dict[str, nn.Parameter] = dict(trainable)
    # Buffers are taken from the full `model`, not `_model`.
    self.buffers_dict: dict[str, torch.Tensor] = dict(self.model.named_buffers())

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

kron #

kron(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, N: int, **kwargs: dict[str, Any]) -> tuple[Tensor, Kron]

Compute a Kronecker-factored curvature approximation (such as KFAC). For each module in the neural network that permits such curvature, the approximation to \(H\) takes the form of two Kronecker factors \(Q\) and \(G\), i.e., \(H \approx Q \otimes G\). \(Q\) is quadratic in the input dimension of the module (\(p_{in} \times p_{in}\)) and \(G\) in its output dimension (\(p_{out} \times p_{out}\)).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

  • N (int) –

    total number of data points

Returns:

  • loss ( Tensor ) –
  • H ( `laplace.utils.matrix.Kron` ) –

    Kronecker factored Hessian approximation.

Source code in laplace/curvature/curvature.py
def kron(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    N: int,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, Kron]:
    """Kronecker-factored curvature approximation (such as KFAC).

    For every module that admits such a factorization, the curvature \\(H\\) is
    approximated by two Kronecker factors \\(Q\\) and \\(G\\), i.e.
    \\(H \\approx Q \\otimes G\\), where \\(Q\\) is quadratic in the module's
    input dimension (\\(p_{in} \\times p_{in}\\)) and \\(G\\) in its output
    dimension (\\(p_{out} \\times p_{out}\\)).

    This base implementation is abstract; concrete curvature interfaces are
    expected to override it.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`
    N : int
        total number of data points

    Returns
    -------
    loss : torch.Tensor
    H : `laplace.utils.matrix.Kron`
        Kronecker factored Hessian approximation.

    Raises
    ------
    NotImplementedError
        always, in this base implementation.
    """
    raise NotImplementedError()

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full EF \(P \times P\) matrix as Hessian approximation \(H_{ef}\) with respect to parameters \(\theta \in \mathbb{R}^P\). For last-layer, reduced to \(\theta_{last}\)

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H_ef ( Tensor ) –

    EF (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full EF \\(P \\times P\\) matrix as Hessian approximation
    \\(H_{ef}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
    For last-layer, reduced to \\(\\theta_{last}\\)

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H_ef : torch.Tensor
        EF `(parameters, parameters)`
    """
    Gs, loss = self.gradients(x, y)
    Gs, loss = Gs.detach(), loss.detach()
    H_ef = torch.einsum("bp,bq->pq", Gs, Gs)
    return self.factor * loss.detach(), self.factor * H_ef

AsdlInterface #

AsdlInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvatureInterface

Interface for asdfghjkl backend.

Source code in laplace/curvature/asdl.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
):
    """Set up the asdfghjkl-backed curvature interface.

    All arguments are forwarded unchanged to the next ``__init__`` in the MRO;
    see ``CurvatureInterface`` for their meaning.
    """
    # Forwarded positionally on purpose: the next class in the MRO depends on
    # the concrete subclass (e.g. GGNInterface or EFInterface), so only the
    # argument order is relied upon.
    forwarded = (
        model,
        likelihood,
        last_layer,
        subnetwork_indices,
        dict_key_x,
        dict_key_y,
    )
    super().__init__(*forwarded)

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a dense curvature (approximation) in the form of a \(P \times P\) matrix \(H\) with respect to parameters \(\theta \in \mathbb{R}^P\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    Hessian approximation (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix
    \\(H\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        Hessian approximation `(parameters, parameters)`
    """
    raise NotImplementedError

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_\theta f(x;\theta)\) at current parameter \(\theta\) using asdfghjkl's gradient per output dimension.

Parameters:

  • x (Tensor or MutableMapping(dict, UserDict)) –

    input data (batch, input_shape) on compatible device with model if torch.Tensor. If MutableMapping, then at least contains self.dict_key_x. The latter is specific for reward modeling.

  • enable_backprop (bool, default: = False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/asdl.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using asdfghjkl's gradient per output dimension.

    One backward pass is run per output index ``i`` (the batch-summed ``i``-th
    output). The result of ASDL's ``batch_gradient`` is indexed as
    `(batch, parameters)` and stacked along a new axis 1.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
        input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
        If MutableMapping, then at least contains `self.dict_key_x`.
        The latter is specific for reward modeling.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    Js = []
    for i in range(self.model.output_size):

        def closure():
            # NOTE(review): relies on batch_gradient invoking the closure within
            # this iteration, so the late-bound ``i`` is the current index.
            self.model.zero_grad()
            f = self.model(x)
            loss = f[:, i].sum()
            loss.backward(
                create_graph=enable_backprop, retain_graph=enable_backprop
            )
            return f

        Ji, f = batch_gradient(
            self.model,
            closure,
            return_outputs=True,
            batch_size=self._get_batch_size(x),
        )
        if self.subnetwork_indices is not None:
            Ji = Ji[:, self.subnetwork_indices]
        Js.append(Ji)
    # Stacking at dim=1 puts the output axis between batch and parameters.
    Js = torch.stack(Js, dim=1)
    return Js, f

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using asdfghjkl's backend.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –

Source code in laplace/curvature/asdl.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at current parameter
    \\(\\theta\\) using asdfghjkl's backend.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`; restricted to `self.subnetwork_indices`
        along the parameter axis when those are set.
    loss : torch.Tensor
        presumably the closure's return value (the batch loss) as forwarded by
        `batch_gradient` with `return_outputs=True` -- confirm against ASDL.
    """

    def closure():
        self.model.zero_grad()
        loss = self.lossfunc(self.model(x), y)
        loss.backward()
        return loss

    Gs, loss = batch_gradient(
        self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)
    )
    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]
    return Gs, loss

_get_batch_size #

_get_batch_size(x: Tensor | MutableMapping[str, Tensor | Any]) -> int | None

ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs.

Source code in laplace/curvature/asdl.py
def _get_batch_size(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
) -> int | None:
    """
    ASDL assumes that all leading dimensions are the batch size by default (batch_size = None).
    Here, we want to specify that only the first dimension is the actual batch size.
    This is the case for LLMs.
    """
    if isinstance(x, MutableMapping):
        return x[self.dict_key_x].shape[0]
    else:
        return None  # Use ASDL default behavior

AsdlGGN #

AsdlGGN(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)

Bases: AsdlInterface, GGNInterface

Implementation of the GGNInterface using asdfghjkl.

Source code in laplace/curvature/asdl.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    stochastic: bool = False,
):
    """Set up the asdfghjkl GGN backend.

    Parameters
    ----------
    stochastic : bool, default=False
        stored on ``self.stochastic``; when True, ``full`` uses a Monte Carlo
        Fisher instead of the exact functional Hessian.

    The remaining arguments are forwarded to ``AsdlInterface``.
    """
    super().__init__(
        model=model,
        likelihood=likelihood,
        last_layer=last_layer,
        subnetwork_indices=subnetwork_indices,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
    )
    # Assigned after the parent constructors run, so this value takes
    # precedence over anything they may have set.
    self.stochastic = stochastic

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_\theta f(x;\theta)\) at current parameter \(\theta\) using asdfghjkl's gradient per output dimension.

Parameters:

  • x (Tensor or MutableMapping(dict, UserDict)) –

    input data (batch, input_shape) on compatible device with model if torch.Tensor. If MutableMapping, then at least contains self.dict_key_x. The latter is specific for reward modeling.

  • enable_backprop (bool, default: = False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/asdl.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using asdfghjkl's gradient per output dimension.

    One backward pass is run per output index ``i`` (the batch-summed ``i``-th
    output). The result of ASDL's ``batch_gradient`` is indexed as
    `(batch, parameters)` and stacked along a new axis 1.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
        input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
        If MutableMapping, then at least contains `self.dict_key_x`.
        The latter is specific for reward modeling.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    Js = []
    for i in range(self.model.output_size):

        def closure():
            # NOTE(review): relies on batch_gradient invoking the closure within
            # this iteration, so the late-bound ``i`` is the current index.
            self.model.zero_grad()
            f = self.model(x)
            loss = f[:, i].sum()
            loss.backward(
                create_graph=enable_backprop, retain_graph=enable_backprop
            )
            return f

        Ji, f = batch_gradient(
            self.model,
            closure,
            return_outputs=True,
            batch_size=self._get_batch_size(x),
        )
        if self.subnetwork_indices is not None:
            Ji = Ji[:, self.subnetwork_indices]
        Js.append(Ji)
    # Stacking at dim=1 puts the output axis between batch and parameters.
    Js = torch.stack(Js, dim=1)
    return Js, f

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using asdfghjkl's backend.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –

Source code in laplace/curvature/asdl.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at current parameter
    \\(\\theta\\) using asdfghjkl's backend.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`; restricted to `self.subnetwork_indices`
        along the parameter axis when those are set.
    loss : torch.Tensor
        presumably the closure's return value (the batch loss) as forwarded by
        `batch_gradient` with `return_outputs=True` -- confirm against ASDL.
    """

    def closure():
        self.model.zero_grad()
        loss = self.lossfunc(self.model(x), y)
        loss.backward()
        return loss

    Gs, loss = batch_gradient(
        self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)
    )
    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]
    return Gs, loss

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full GGN \(P \times P\) matrix as Hessian approximation \(H_{ggn}\) with respect to parameters \(\theta \in \mathbb{R}^P\). For last-layer, reduced to \(\theta_{last}\)

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    GGN (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation
    \\(H_{ggn}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
    For last-layer, reduced to \\(\\theta_{last}\\)

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        GGN `(parameters, parameters)`
    """
    Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
    H_lik = (
        self._get_mc_functional_fisher(f)
        if self.stochastic
        else self._get_functional_hessian(f)
    )

    if H_lik is not None:
        H = torch.einsum("bcp,bck,bkq->pq", Js, H_lik, Js)
    else:  # The case of exact GGN for regression
        H = torch.einsum("bcp,bcq->pq", Js, Js)
    loss = self.factor * self.lossfunc(f, y)

    return loss.detach(), H.detach()

_get_mc_functional_fisher #

_get_mc_functional_fisher(f: Tensor) -> Tensor

Approximate the Fisher's middle matrix (the expected outer product of the functional gradient) with a Monte Carlo estimate over self.num_samples samples.

Source code in laplace/curvature/curvature.py
def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:
    """Approximate the Fisher's middle matrix (expected outer product of the functional gradient)
    using MC integral with `self.num_samples` many samples.
    """
    F = 0

    for _ in range(self.num_samples):
        if self.likelihood == "regression":
            y_sample = f + torch.randn(f.shape, device=f.device)  # N(y | f, 1)
            grad_sample = f - y_sample  # functional MSE grad
        else:  # classification with softmax
            y_sample = torch.distributions.Multinomial(logits=f).sample()
            # First functional derivative of the loglik is p - y
            p = torch.softmax(f, dim=-1)
            grad_sample = p - y_sample

        F += (
            1
            / self.num_samples
            * torch.einsum("bc,bk->bck", grad_sample, grad_sample)
        )

    return F

_get_batch_size #

_get_batch_size(x: Tensor | MutableMapping[str, Tensor | Any]) -> int | None

ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs.

Source code in laplace/curvature/asdl.py
def _get_batch_size(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
) -> int | None:
    """
    ASDL assumes that all leading dimensions are the batch size by default (batch_size = None).
    Here, we want to specify that only the first dimension is the actual batch size.
    This is the case for LLMs.
    """
    if isinstance(x, MutableMapping):
        return x[self.dict_key_x].shape[0]
    else:
        return None  # Use ASDL default behavior

AsdlEF #

AsdlEF(model: Module, likelihood: Likelihood | str, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: AsdlInterface, EFInterface

Implementation of the EFInterface using asdfghjkl.

Source code in laplace/curvature/asdl.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
):
    """Set up the asdfghjkl empirical-Fisher backend.

    Subnetwork selection is not exposed here: ``subnetwork_indices`` is always
    forwarded as ``None``. The remaining arguments go to ``AsdlInterface``.
    """
    super().__init__(
        model=model,
        likelihood=likelihood,
        last_layer=last_layer,
        subnetwork_indices=None,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_\theta f(x;\theta)\) at current parameter \(\theta\) using asdfghjkl's gradient per output dimension.

Parameters:

  • x (Tensor or MutableMapping(dict, UserDict)) –

    input data (batch, input_shape) on compatible device with model if torch.Tensor. If MutableMapping, then at least contains self.dict_key_x. The latter is specific for reward modeling.

  • enable_backprop (bool, default: = False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/asdl.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using asdfghjkl's gradient per output dimension.

    One backward pass is run per output index ``i`` (the batch-summed ``i``-th
    output). The result of ASDL's ``batch_gradient`` is indexed as
    `(batch, parameters)` and stacked along a new axis 1.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
        input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
        If MutableMapping, then at least contains `self.dict_key_x`.
        The latter is specific for reward modeling.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    Js = []
    for i in range(self.model.output_size):

        def closure():
            # NOTE(review): relies on batch_gradient invoking the closure within
            # this iteration, so the late-bound ``i`` is the current index.
            self.model.zero_grad()
            f = self.model(x)
            loss = f[:, i].sum()
            loss.backward(
                create_graph=enable_backprop, retain_graph=enable_backprop
            )
            return f

        Ji, f = batch_gradient(
            self.model,
            closure,
            return_outputs=True,
            batch_size=self._get_batch_size(x),
        )
        if self.subnetwork_indices is not None:
            Ji = Ji[:, self.subnetwork_indices]
        Js.append(Ji)
    # Stacking at dim=1 puts the output axis between batch and parameters.
    Js = torch.stack(Js, dim=1)
    return Js, f

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using asdfghjkl's backend.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –

Source code in laplace/curvature/asdl.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at current parameter
    \\(\\theta\\) using asdfghjkl's backend.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`; restricted to `self.subnetwork_indices`
        along the parameter axis when those are set.
    loss : torch.Tensor
        presumably the closure's return value (the batch loss) as forwarded by
        `batch_gradient` with `return_outputs=True` -- confirm against ASDL.
    """

    def closure():
        self.model.zero_grad()
        loss = self.lossfunc(self.model(x), y)
        loss.backward()
        return loss

    Gs, loss = batch_gradient(
        self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)
    )
    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]
    return Gs, loss

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full EF \(P \times P\) matrix as Hessian approximation \(H_{ef}\) with respect to parameters \(\theta \in \mathbb{R}^P\). For last-layer, reduced to \(\theta_{last}\)

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H_ef ( Tensor ) –

    EF (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full EF \\(P \\times P\\) matrix as Hessian approximation
    \\(H_{ef}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
    For last-layer, reduced to \\(\\theta_{last}\\)

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H_ef : torch.Tensor
        EF `(parameters, parameters)`
    """
    Gs, loss = self.gradients(x, y)
    Gs, loss = Gs.detach(), loss.detach()
    H_ef = torch.einsum("bp,bq->pq", Gs, Gs)
    return self.factor * loss.detach(), self.factor * H_ef

_get_batch_size #

_get_batch_size(x: Tensor | MutableMapping[str, Tensor | Any]) -> int | None

ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs.

Source code in laplace/curvature/asdl.py
def _get_batch_size(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
) -> int | None:
    """
    ASDL assumes that all leading dimensions are the batch size by default (batch_size = None).
    Here, we want to specify that only the first dimension is the actual batch size.
    This is the case for LLMs.
    """
    if isinstance(x, MutableMapping):
        return x[self.dict_key_x].shape[0]
    else:
        return None  # Use ASDL default behavior

AsdlHessian #

AsdlHessian(model: Module, likelihood: Likelihood | str, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: AsdlInterface

Source code in laplace/curvature/asdl.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Set up the asdfghjkl exact-Hessian backend.

    Subnetwork selection is not supported, so ``None`` is always forwarded in
    the ``subnetwork_indices`` slot. The other arguments go to ``AsdlInterface``.
    """
    no_subnetwork = None
    super().__init__(
        model, likelihood, last_layer, no_subnetwork, dict_key_x, dict_key_y
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_\theta f(x;\theta)\) at current parameter \(\theta\) using asdfghjkl's gradient per output dimension.

Parameters:

  • x (Tensor or MutableMapping(dict, UserDict)) –

    input data (batch, input_shape) on compatible device with model if torch.Tensor. If MutableMapping, then at least contains self.dict_key_x. The latter is specific for reward modeling.

  • enable_backprop (bool, default: = False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/asdl.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using asdfghjkl's gradient per output dimension.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
        input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
        If MutableMapping, then at least contains `self.dict_key_x`.
        The latter is specific for reward modeling.
    enable_backprop : bool, default=False
        whether to enable backprop through the Js and f w.r.t. x. When False,
        both returned tensors are detached from the autograd graph.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    Js = []
    for i in range(self.model.output_size):

        # `closure` reads the loop variable `i`; this is safe because
        # batch_gradient invokes it within the same iteration.
        def closure():
            self.model.zero_grad()
            f = self.model(x)
            loss = f[:, i].sum()
            loss.backward(
                create_graph=enable_backprop, retain_graph=enable_backprop
            )
            return f

        # Ji: gradients of output i, one row per sample; f: the closure's output.
        Ji, f = batch_gradient(
            self.model,
            closure,
            return_outputs=True,
            batch_size=self._get_batch_size(x),
        )
        if self.subnetwork_indices is not None:
            Ji = Ji[:, self.subnetwork_indices]
        Js.append(Ji)

    # Each Ji is (batch, parameters); stacking on dim 1 gives (batch, outputs, parameters).
    Js = torch.stack(Js, dim=1)
    # Match the other backends: only keep the graph when explicitly requested.
    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    The Jacobian is assembled in closed form from the features `phi` returned by
    `self.model.forward_with_features`; this construction assumes the last layer
    is an affine map `W phi + b`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data passed to `self.model.forward_with_features`.
    enable_backprop : bool, default=False
        whether to keep the returned tensors attached to the autograd graph.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`. Weight entries are
        ordered like a row-major flattened `(outputs, features)` matrix, followed
        by the bias entries when the last layer has a bias.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = f.numel() // bsize

    # Create the identity on phi's own device and dtype so that Js is neither
    # promoted to (nor clashes with) torch.eye's default float32, and so that it
    # lives where the features live even if the model is sharded across devices.
    identity = (
        torch.eye(output_size, device=phi.device, dtype=phi.dtype)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        # d f_i / d b_j = delta_ij
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using asdfghjkl's backend.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –

Source code in laplace/curvature/asdl.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at current parameter
    \\(\\theta\\) using asdfghjkl's backend.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
        input data `(batch, input_shape)` on compatible device with model.
        If MutableMapping, then at least contains `self.dict_key_x`.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`, restricted to `self.subnetwork_indices`
        when set.
    loss : torch.Tensor
        the loss computed inside the closure, as returned by `batch_gradient`.
    """

    def closure():
        # Clear stale .grad values, then backprop the batch loss.
        self.model.zero_grad()
        loss = self.lossfunc(self.model(x), y)
        loss.backward()
        return loss

    Gs, loss = batch_gradient(
        self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)
    )
    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]
    return Gs, loss

_get_batch_size #

_get_batch_size(x: Tensor | MutableMapping[str, Tensor | Any]) -> int | None

ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs.

Source code in laplace/curvature/asdl.py
def _get_batch_size(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
) -> int | None:
    """Return the batch size that should be handed to ASDL for input ``x``.

    By default (``batch_size=None``) ASDL treats every leading dimension as a
    batch dimension. For dict-like inputs (e.g. LLM token batches) only the
    first dimension of ``x[self.dict_key_x]`` is the true batch size, so it is
    returned explicitly; plain tensors keep ASDL's default.
    """
    if not isinstance(x, MutableMapping):
        # Plain tensor input: defer to ASDL's default behaviour.
        return None
    main_input = x[self.dict_key_x]
    return main_input.shape[0]

BackPackInterface #

BackPackInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvatureInterface

Interface for Backpack backend.

Source code in laplace/curvature/backpack.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Initialise the curvature interface and register modules with BackPACK.

    All arguments are forwarded unchanged to the parent interface; afterwards
    `self._model` and `self.lossfunc` are passed to BackPACK's `extend`.
    """
    super().__init__(
        model,
        likelihood,
        last_layer,
        subnetwork_indices,
        dict_key_x,
        dict_key_y,
    )

    # BackPACK extensions (e.g. BatchGrad) only act on extended modules.
    for module in (self._model, self.lossfunc):
        extend(module)

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    The Jacobian is assembled in closed form from the features `phi` returned by
    `self.model.forward_with_features`; this construction assumes the last layer
    is an affine map `W phi + b`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data passed to `self.model.forward_with_features`.
    enable_backprop : bool, default=False
        whether to keep the returned tensors attached to the autograd graph.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`. Weight entries are
        ordered like a row-major flattened `(outputs, features)` matrix, followed
        by the bias entries when the last layer has a bias.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = f.numel() // bsize

    # Create the identity on phi's own device and dtype so that Js is neither
    # promoted to (nor clashes with) torch.eye's default float32, and so that it
    # lives where the features live even if the model is sharded across devices.
    identity = (
        torch.eye(output_size, device=phi.device, dtype=phi.dtype)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        # d f_i / d b_j = delta_ij
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a dense curvature (approximation) in the form of a \(P \times P\) matrix \(H\) with respect to parameters \(\theta \in \mathbb{R}^P\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    Hessian approximation (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Return a dense curvature (approximation) \\(H\\), a \\(P \\times P\\) matrix
    taken with respect to the parameters \\(\\theta \\in \\mathbb{R}^P\\).

    Abstract hook: every concrete curvature backend has to provide it.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        Hessian approximation `(parameters, parameters)`

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()

kron #

kron(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, N: int, **kwargs: dict[str, Any]) -> tuple[Tensor, Kron]

Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to \(H\) takes the form of two Kronecker factors \(Q\) and \(G\), i.e., \(H \approx Q \otimes G\), for each Module in the neural network permitting such curvature. \(Q\) is quadratic in the input dimension of a module (\(p_{in} \times p_{in}\)) and \(G\) in the output dimension (\(p_{out} \times p_{out}\)).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

  • N (int) –

    total number of data points

Returns:

  • loss ( Tensor ) –
  • H ( `laplace.utils.matrix.Kron` ) –

    Kronecker factored Hessian approximation.

Source code in laplace/curvature/curvature.py
def kron(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    N: int,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, Kron]:
    """Return a Kronecker-factored curvature approximation (such as KFAC).

    For every module in the network that admits such a structure, \\(H\\) is
    approximated by two Kronecker factors, \\(H \\approx Q \\otimes G\\), where
    \\(Q\\) is quadratic in the module's input dimension (\\(p_{in} \\times p_{in}\\))
    and \\(G\\) is quadratic in its output dimension (\\(p_{out} \\times p_{out}\\)).

    Abstract hook: every concrete curvature backend has to provide it.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`
    N : int
        total number of data points

    Returns
    -------
    loss : torch.Tensor
    H : `laplace.utils.matrix.Kron`
        Kronecker factored Hessian approximation.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()

diag #

diag(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a diagonal Hessian approximation to \(H\) and is represented as a vector of the dimensionality of parameters \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    vector representing the diagonal of H

Source code in laplace/curvature/curvature.py
def diag(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Return a diagonal approximation of \\(H\\), stored as a vector with one
    entry per parameter in \\(\\theta\\).

    Abstract hook: every concrete curvature backend has to provide it.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        vector representing the diagonal of H

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\) using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/backpack.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well
    with torch.func, so this method has to be overridden.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
        Dict-like inputs are not supported by this backend.
    enable_backprop : bool, default=False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`

    Raises
    ------
    ValueError
        If `x` is a MutableMapping.
    """
    if isinstance(x, MutableMapping):
        raise ValueError("BackPACK backend does not support dict-like inputs!")

    model = extend(self.model)
    batch_size = x.shape[0]
    to_stack = []
    try:
        for i in range(model.output_size):
            model.zero_grad()
            out = model(x)
            with backpack(BatchGrad()):
                # Backprop one output dimension at a time; BatchGrad leaves the
                # per-sample gradients in `param.grad_batch`.
                if model.output_size > 1:
                    out[:, i].sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                else:
                    out.sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                to_cat = []
                for param in model.parameters():
                    to_cat.append(param.grad_batch.reshape(batch_size, -1))
                    # Drop it so the next output dimension starts clean.
                    delattr(param, "grad_batch")
                Jk = torch.cat(to_cat, dim=1)
                if self.subnetwork_indices is not None:
                    Jk = Jk[:, self.subnetwork_indices]
            to_stack.append(Jk)
            if i == 0:
                f = out
    finally:
        # Release BackPACK hooks/extension state even if a backward pass fails.
        model.zero_grad()
        CTX.remove_hooks()
        _cleanup(model)

    # Each Jk is (batch, parameters); stacking on dim 1 gives (batch, outputs, parameters).
    J = torch.stack(to_stack, dim=1)

    return (J, f) if enable_backprop else (J.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using Backpack's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/backpack.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-sample gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at the
    current parameter \\(\\theta\\) via BackPACK's ``BatchGrad`` extension.
    BackPACK is incompatible with ``torch.func``, hence this override.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)` over `self._model.parameters()`,
        restricted to `self.subnetwork_indices` when set.
    loss : torch.Tensor
        the loss value computed by `self.lossfunc`.
    """
    loss = self.lossfunc(self.model(x), y)
    with backpack(BatchGrad()):
        loss.backward()

    per_param = [
        p.grad_batch.data.flatten(start_dim=1) for p in self._model.parameters()
    ]
    Gs = torch.cat(per_param, dim=1)
    if self.subnetwork_indices is None:
        return Gs, loss
    return Gs[:, self.subnetwork_indices], loss

BackPackGGN #

BackPackGGN(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)

Bases: BackPackInterface, GGNInterface

Implementation of the GGNInterface using Backpack.

Source code in laplace/curvature/backpack.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    stochastic: bool = False,
):
    """Set up the BackPACK GGN backend.

    Parameters
    ----------
    stochastic : bool, default=False
        If True, `full` uses the Monte Carlo estimate from
        `_get_mc_functional_fisher` instead of the exact functional Hessian.

    All other arguments are forwarded unchanged to the parent interface.
    """
    super().__init__(
        model,
        likelihood,
        last_layer=last_layer,
        subnetwork_indices=subnetwork_indices,
        dict_key_x=dict_key_x,
        dict_key_y=dict_key_y,
    )
    self.stochastic = stochastic

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\) using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/backpack.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well
    with torch.func, so this method has to be overridden.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
        Dict-like inputs are not supported by this backend.
    enable_backprop : bool, default=False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`

    Raises
    ------
    ValueError
        If `x` is a MutableMapping.
    """
    if isinstance(x, MutableMapping):
        raise ValueError("BackPACK backend does not support dict-like inputs!")

    model = extend(self.model)
    batch_size = x.shape[0]
    to_stack = []
    try:
        for i in range(model.output_size):
            model.zero_grad()
            out = model(x)
            with backpack(BatchGrad()):
                # Backprop one output dimension at a time; BatchGrad leaves the
                # per-sample gradients in `param.grad_batch`.
                if model.output_size > 1:
                    out[:, i].sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                else:
                    out.sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                to_cat = []
                for param in model.parameters():
                    to_cat.append(param.grad_batch.reshape(batch_size, -1))
                    # Drop it so the next output dimension starts clean.
                    delattr(param, "grad_batch")
                Jk = torch.cat(to_cat, dim=1)
                if self.subnetwork_indices is not None:
                    Jk = Jk[:, self.subnetwork_indices]
            to_stack.append(Jk)
            if i == 0:
                f = out
    finally:
        # Release BackPACK hooks/extension state even if a backward pass fails.
        model.zero_grad()
        CTX.remove_hooks()
        _cleanup(model)

    # Each Jk is (batch, parameters); stacking on dim 1 gives (batch, outputs, parameters).
    J = torch.stack(to_stack, dim=1)

    return (J, f) if enable_backprop else (J.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    The Jacobian is assembled in closed form from the features `phi` returned by
    `self.model.forward_with_features`; this construction assumes the last layer
    is an affine map `W phi + b`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data passed to `self.model.forward_with_features`.
    enable_backprop : bool, default=False
        whether to keep the returned tensors attached to the autograd graph.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`. Weight entries are
        ordered like a row-major flattened `(outputs, features)` matrix, followed
        by the bias entries when the last layer has a bias.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = f.numel() // bsize

    # Create the identity on phi's own device and dtype so that Js is neither
    # promoted to (nor clashes with) torch.eye's default float32, and so that it
    # lives where the features live even if the model is sharded across devices.
    identity = (
        torch.eye(output_size, device=phi.device, dtype=phi.dtype)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        # d f_i / d b_j = delta_ij
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using Backpack's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/backpack.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-sample gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at the
    current parameter \\(\\theta\\) via BackPACK's ``BatchGrad`` extension.
    BackPACK is incompatible with ``torch.func``, hence this override.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor
        targets passed to `self.lossfunc` together with the model output.

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)` over `self._model.parameters()`,
        restricted to `self.subnetwork_indices` when set.
    loss : torch.Tensor
        the loss value computed by `self.lossfunc`.
    """
    loss = self.lossfunc(self.model(x), y)
    with backpack(BatchGrad()):
        loss.backward()

    per_param = [
        p.grad_batch.data.flatten(start_dim=1) for p in self._model.parameters()
    ]
    Gs = torch.cat(per_param, dim=1)
    if self.subnetwork_indices is None:
        return Gs, loss
    return Gs[:, self.subnetwork_indices], loss

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full GGN \(P \times P\) matrix as Hessian approximation \(H_{ggn}\) with respect to parameters \(\theta \in \mathbb{R}^P\). For last-layer, reduced to \(\theta_{last}\)

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    GGN (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full \\(P \\times P\\) GGN matrix \\(H_{ggn}\\) as a Hessian
    approximation with respect to the parameters \\(\\theta \\in \\mathbb{R}^P\\).
    In last-layer mode the parameters are reduced to \\(\\theta_{last}\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
        `self.factor * self.lossfunc(f, y)`, detached.
    H : torch.Tensor
        GGN `(parameters, parameters)`, detached.
    """
    jacobian_fn = self.last_layer_jacobians if self.last_layer else self.jacobians
    Js, f = jacobian_fn(x)

    if self.stochastic:
        H_lik = self._get_mc_functional_fisher(f)
    else:
        H_lik = self._get_functional_hessian(f)

    if H_lik is None:
        # Exact GGN for regression: no likelihood Hessian is inserted.
        H = torch.einsum("bcp,bcq->pq", Js, Js)
    else:
        # J^T H_lik J, summed over the batch.
        H = torch.einsum("bcp,bck,bkq->pq", Js, H_lik, Js)

    loss = self.factor * self.lossfunc(f, y)
    return loss.detach(), H.detach()

_get_mc_functional_fisher #

_get_mc_functional_fisher(f: Tensor) -> Tensor

Approximate the Fisher's middle matrix (expected outer product of the functional gradient) using MC integral with self.num_samples many samples.

Source code in laplace/curvature/curvature.py
def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:
    """Approximate the Fisher's middle matrix (expected outer product of the
    functional gradient) by a Monte Carlo average over `self.num_samples` draws.

    Parameters
    ----------
    f : torch.Tensor
        model outputs `(batch, outputs)`.

    Returns
    -------
    F : torch.Tensor
        per-sample middle matrices `(batch, outputs, outputs)`.

    Raises
    ------
    ValueError
        If `self.num_samples` is smaller than 1 (the average would be undefined).
    """
    if self.num_samples < 1:
        raise ValueError(
            f"num_samples must be a positive integer, got {self.num_samples}."
        )

    F = None
    for _ in range(self.num_samples):
        if self.likelihood == "regression":
            # y ~ N(y | f, 1); noise drawn in f's dtype/device.
            y_sample = f + torch.randn_like(f)
            grad_sample = f - y_sample  # functional MSE grad
        else:  # classification with softmax
            y_sample = torch.distributions.Multinomial(logits=f).sample()
            # First functional derivative of the loglik is p - y
            grad_sample = torch.softmax(f, dim=-1) - y_sample

        outer = torch.einsum("bc,bk->bck", grad_sample, grad_sample)
        F = outer if F is None else F + outer

    # Scale once at the end instead of on every sample.
    return F / self.num_samples

BackPackEF #

BackPackEF(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: BackPackInterface, EFInterface

Implementation of EFInterface using Backpack.

Source code in laplace/curvature/backpack.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Initialise the curvature interface and register modules with BackPACK.

    All arguments are forwarded unchanged to the parent interface; afterwards
    `self._model` and `self.lossfunc` are passed to BackPACK's `extend`.
    """
    super().__init__(
        model,
        likelihood,
        last_layer,
        subnetwork_indices,
        dict_key_x,
        dict_key_y,
    )

    # BackPACK extensions (e.g. BatchGrad) only act on extended modules.
    for module in (self._model, self.lossfunc):
        extend(module)

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\) using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/backpack.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well
    with torch.func, so this method has to be overridden.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
        Dict-like inputs are not supported by this backend.
    enable_backprop : bool, default=False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`; the parameter axis is
        restricted to `self.subnetwork_indices` when those are set.
    f : torch.Tensor
        output function `(batch, outputs)`

    Raises
    ------
    ValueError
        If `x` is a MutableMapping.
    """
    if isinstance(x, MutableMapping):
        raise ValueError("BackPACK backend does not support dict-like inputs!")

    model = extend(self.model)
    batch_size = x.shape[0]
    to_stack = []
    try:
        for i in range(model.output_size):
            model.zero_grad()
            out = model(x)
            with backpack(BatchGrad()):
                # Backprop one output dimension at a time; BatchGrad leaves the
                # per-sample gradients in `param.grad_batch`.
                if model.output_size > 1:
                    out[:, i].sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                else:
                    out.sum().backward(
                        create_graph=enable_backprop, retain_graph=enable_backprop
                    )
                to_cat = []
                for param in model.parameters():
                    to_cat.append(param.grad_batch.reshape(batch_size, -1))
                    # Drop it so the next output dimension starts clean.
                    delattr(param, "grad_batch")
                Jk = torch.cat(to_cat, dim=1)
                if self.subnetwork_indices is not None:
                    Jk = Jk[:, self.subnetwork_indices]
            to_stack.append(Jk)
            if i == 0:
                f = out
    finally:
        # Release BackPACK hooks/extension state even if a backward pass fails.
        model.zero_grad()
        CTX.remove_hooks()
        _cleanup(model)

    # Each Jk is (batch, parameters); stacking on dim 1 gives (batch, outputs, parameters).
    J = torch.stack(to_stack, dim=1)

    return (J, f) if enable_backprop else (J.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    The Jacobian is assembled in closed form from the features `phi` returned by
    `self.model.forward_with_features`; this construction assumes the last layer
    is an affine map `W phi + b`.

    Parameters
    ----------
    x : torch.Tensor or MutableMapping
        input data passed to `self.model.forward_with_features`.
    enable_backprop : bool, default=False
        whether to keep the returned tensors attached to the autograd graph.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`. Weight entries are
        ordered like a row-major flattened `(outputs, features)` matrix, followed
        by the bias entries when the last layer has a bias.
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = f.numel() // bsize

    # Create the identity on phi's own device and dtype so that Js is neither
    # promoted to (nor clashes with) torch.eye's default float32, and so that it
    # lives where the features live even if the model is sharded across devices.
    identity = (
        torch.eye(output_size, device=phi.device, dtype=phi.dtype)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        # d f_i / d b_j = delta_ij
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at the current parameter \(\theta\) using Backpack's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/backpack.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-sample gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at the current
    parameters \\(\\theta\\), obtained from BackPACK's ``BatchGrad`` extension.

    BackPACK does not compose with ``torch.func``, which is why this backend
    overrides the generic implementation.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`, restricted to ``self.subnetwork_indices``
        when those are set
    loss : torch.Tensor
        the value returned by ``self.lossfunc`` (not detached)
    """
    loss = self.lossfunc(self.model(x), y)

    # Inside the BackPACK context, backward() also stores each parameter's
    # individual per-sample gradients as ``grad_batch``.
    with backpack(BatchGrad()):
        loss.backward()

    per_sample = tuple(
        param.grad_batch.data.flatten(start_dim=1)
        for param in self._model.parameters()
    )
    Gs = torch.cat(per_sample, dim=1)

    if self.subnetwork_indices is None:
        return Gs, loss
    return Gs[:, self.subnetwork_indices], loss

full #

full(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any]) -> tuple[Tensor, Tensor]

Compute the full empirical Fisher (EF) \(P \times P\) matrix as a Hessian approximation \(H_{ef}\) with respect to the parameters \(\theta \in \mathbb{R}^P\). For the last-layer approximation, \(\theta\) is reduced to \(\theta_{last}\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H_ef ( Tensor ) –

    EF (parameters, parameters)

Source code in laplace/curvature/curvature.py
def full(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the full EF \\(P \\times P\\) matrix as Hessian approximation
    \\(H_{ef}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
    For last-layer, reduced to \\(\\theta_{last}\\)

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H_ef : torch.Tensor
        EF `(parameters, parameters)`
    """
    Gs, loss = self.gradients(x, y)
    Gs, loss = Gs.detach(), loss.detach()
    H_ef = torch.einsum("bp,bq->pq", Gs, Gs)
    return self.factor * loss.detach(), self.factor * H_ef

CurvlinopsInterface #

CurvlinopsInterface(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvatureInterface

Interface for Curvlinops backend. https://github.com/f-dangel/curvlinops

Source code in laplace/curvature/curvlinops.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Initialize the Curvlinops-backed curvature interface.

    All arguments are forwarded positionally and unchanged to the next
    ``__init__`` in the method resolution order; this initializer adds no
    state of its own.

    Parameters
    ----------
    model : torch.nn.Module
        torch model (neural network)
    likelihood : Likelihood or str
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.LongTensor, optional
        indices of the vectorized model parameters that define the subnetwork
    dict_key_x : str, default="input_ids"
        dictionary key of the input tensor when the model takes a MutableMapping
    dict_key_y : str, default="labels"
        dictionary key of the target tensor when the model takes a MutableMapping
    """
    super().__init__(
        model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

diag #

diag(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a diagonal Hessian approximation to \(H\), represented as a vector with the dimensionality of the parameters \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    vector representing the diagonal of H

Source code in laplace/curvature/curvature.py
def diag(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
    vector of the dimensionality of parameters \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        vector representing the diagonal of H
    """
    raise NotImplementedError

CurvlinopsGGN #

CurvlinopsGGN(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)

Bases: CurvlinopsInterface, GGNInterface

Implementation of the GGNInterface using Curvlinops.

Source code in laplace/curvature/curvlinops.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
    stochastic: bool = False,
) -> None:
    """Initialize the Curvlinops-backed GGN interface.

    The first six arguments are forwarded positionally to the next ``__init__``
    in the method resolution order; ``stochastic`` is not forwarded and is
    stored on the instance afterwards.

    Parameters
    ----------
    model : torch.nn.Module
        torch model (neural network)
    likelihood : Likelihood or str
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.LongTensor, optional
        indices of the vectorized model parameters that define the subnetwork
    dict_key_x : str, default="input_ids"
        dictionary key of the input tensor when the model takes a MutableMapping
    dict_key_y : str, default="labels"
        dictionary key of the target tensor when the model takes a MutableMapping
    stochastic : bool, default=False
        stored as ``self.stochastic``.
        NOTE(review): presumably selects a Monte Carlo (sampled) estimate instead
        of the exact GGN -- confirm against the methods that read it.
    """
    super().__init__(
        model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
    )
    # Assigned after super().__init__(); keep this order when refactoring in
    # case a parent initializer also sets ``stochastic``.
    self.stochastic = stochastic

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

_get_mc_functional_fisher #

_get_mc_functional_fisher(f: Tensor) -> Tensor

Approximate the Fisher's middle matrix (the expected outer product of the functional gradient) with a Monte Carlo estimate using self.num_samples samples.

Source code in laplace/curvature/curvature.py
def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:
    """Approximate the Fisher's middle matrix (expected outer product of the functional gradient)
    using MC integral with `self.num_samples` many samples.
    """
    F = 0

    for _ in range(self.num_samples):
        if self.likelihood == "regression":
            y_sample = f + torch.randn(f.shape, device=f.device)  # N(y | f, 1)
            grad_sample = f - y_sample  # functional MSE grad
        else:  # classification with softmax
            y_sample = torch.distributions.Multinomial(logits=f).sample()
            # First functional derivative of the loglik is p - y
            p = torch.softmax(f, dim=-1)
            grad_sample = p - y_sample

        F += (
            1
            / self.num_samples
            * torch.einsum("bc,bk->bck", grad_sample, grad_sample)
        )

    return F

CurvlinopsEF #

CurvlinopsEF(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvlinopsInterface, EFInterface

Implementation of EFInterface using Curvlinops.

Source code in laplace/curvature/curvlinops.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Initialize the Curvlinops-backed empirical Fisher (EF) interface.

    All arguments are forwarded positionally and unchanged to the next
    ``__init__`` in the method resolution order; this initializer adds no
    state of its own.

    Parameters
    ----------
    model : torch.nn.Module
        torch model (neural network)
    likelihood : Likelihood or str
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.LongTensor, optional
        indices of the vectorized model parameters that define the subnetwork
    dict_key_x : str, default="input_ids"
        dictionary key of the input tensor when the model takes a MutableMapping
    dict_key_y : str, default="labels"
        dictionary key of the target tensor when the model takes a MutableMapping
    """
    super().__init__(
        model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

CurvlinopsHessian #

CurvlinopsHessian(model: Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')

Bases: CurvlinopsInterface

Implementation of the full Hessian using Curvlinops.

Source code in laplace/curvature/curvlinops.py
def __init__(
    self,
    model: nn.Module,
    likelihood: Likelihood | str,
    last_layer: bool = False,
    subnetwork_indices: torch.LongTensor | None = None,
    dict_key_x: str = "input_ids",
    dict_key_y: str = "labels",
) -> None:
    """Initialize the Curvlinops-backed full-Hessian interface.

    All arguments are forwarded positionally and unchanged to the next
    ``__init__`` in the method resolution order; this initializer adds no
    state of its own.

    Parameters
    ----------
    model : torch.nn.Module
        torch model (neural network)
    likelihood : Likelihood or str
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.LongTensor, optional
        indices of the vectorized model parameters that define the subnetwork
    dict_key_x : str, default="input_ids"
        dictionary key of the input tensor when the model takes a MutableMapping
    dict_key_y : str, default="labels"
        dictionary key of the target tensor when the model takes a MutableMapping
    """
    super().__init__(
        model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
    )

jacobians #

jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta} f(x;\theta)\) at current parameter \(\theta\), via torch.func.

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the Js and f w.r.t. x

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
    via torch.func.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    enable_backprop : bool, default = False
        whether to enable backprop through the Js and f w.r.t. x

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, parameters, outputs)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """

    def model_fn_params_only(params_dict, buffers_dict):
        out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
        return out, out

    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
        self.params_dict, self.buffers_dict
    )

    # Concatenate over flattened parameters
    Js = [
        j.flatten(start_dim=-p.dim())
        for j, p in zip(Js.values(), self.params_dict.values())
    ]
    Js = torch.cat(Js, dim=-1)

    if self.subnetwork_indices is not None:
        Js = Js[:, :, self.subnetwork_indices]

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

last_layer_jacobians #

last_layer_jacobians(x: Tensor | MutableMapping[str, Tensor | Any], enable_backprop: bool = False) -> tuple[Tensor, Tensor]

Compute Jacobians \(\nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})\) only at current last-layer parameter \(\theta_{\textrm{last}}\).

Parameters:

  • x (Tensor) –
  • enable_backprop (bool, default: False ) –

Returns:

  • Js ( Tensor ) –

    Jacobians (batch, outputs, last-layer-parameters)

  • f ( Tensor ) –

    output function (batch, outputs)

Source code in laplace/curvature/curvature.py
def last_layer_jacobians(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    enable_backprop: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
    only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

    Parameters
    ----------
    x : torch.Tensor
    enable_backprop : bool, default=False

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, last-layer-parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    f, phi = self.model.forward_with_features(x)
    bsize = phi.shape[0]
    output_size = int(f.numel() / bsize)

    # calculate Jacobians using the feature vector 'phi'
    identity = (
        torch.eye(output_size, device=next(self.model.parameters()).device)
        .unsqueeze(0)
        .tile(bsize, 1, 1)
    )
    # Jacobians are batch x output x params
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
    if self.model.last_layer.bias is not None:
        Js = torch.cat([Js, identity], dim=2)

    return (Js, f) if enable_backprop else (Js.detach(), f.detach())

gradients #

gradients(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor) -> tuple[Tensor, Tensor]

Compute batch gradients \(\nabla_\theta \ell(f(x;\theta), y)\) at current parameter \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape) on compatible device with model.

  • y (Tensor) –

Returns:

  • Gs ( Tensor ) –

    gradients (batch, parameters)

  • loss ( Tensor ) –
Source code in laplace/curvature/curvature.py
def gradients(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
    current parameter \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    y : torch.Tensor

    Returns
    -------
    Gs : torch.Tensor
        gradients `(batch, parameters)`
    loss : torch.Tensor
    """

    def loss_single(x, y, params_dict, buffers_dict):
        """Compute the gradient for a single sample."""
        x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
        output = torch.func.functional_call(
            self.model, (params_dict, buffers_dict), x
        )
        loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
        return loss, loss

    grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
    batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))

    batch_grad, batch_loss = batch_grad_fn(
        x, y, self.params_dict, self.buffers_dict
    )
    Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

    if self.subnetwork_indices is not None:
        Gs = Gs[:, self.subnetwork_indices]

    loss = batch_loss.sum(0)

    return Gs, loss

diag #

diag(x: Tensor | MutableMapping[str, Tensor | Any], y: Tensor, **kwargs: dict[str, Any])

Compute a diagonal Hessian approximation to \(H\), represented as a vector with the dimensionality of the parameters \(\theta\).

Parameters:

  • x (Tensor) –

    input data (batch, input_shape)

  • y (Tensor) –

    labels (batch, label_shape)

Returns:

  • loss ( Tensor ) –
  • H ( Tensor ) –

    vector representing the diagonal of H

Source code in laplace/curvature/curvature.py
def diag(
    self,
    x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
    y: torch.Tensor,
    **kwargs: dict[str, Any],
):
    """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
    vector of the dimensionality of parameters \\(\\theta\\).

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)`
    y : torch.Tensor
        labels `(batch, label_shape)`

    Returns
    -------
    loss : torch.Tensor
    H : torch.Tensor
        vector representing the diagonal of H
    """
    raise NotImplementedError