laplace.utils #

SoDSampler #

SoDSampler(N, M, seed: int = 0)

Bases: Sampler

Source code in laplace/utils/utils.py
def __init__(self, N, M, seed: int = 0):
    rng = np.random.default_rng(seed)
    self.indices = torch.tensor(rng.choice(list(range(N)), M, replace=False))
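
A minimal usage sketch, assuming SoDSampler is exported from laplace.utils and iterates over self.indices as a standard Sampler: draw a fixed subset of M points from a dataset of N points and feed it to a DataLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace.utils import SoDSampler

dataset = TensorDataset(torch.randn(1000, 3), torch.randint(0, 2, (1000,)))
sampler = SoDSampler(N=len(dataset), M=64, seed=0)  # 64 indices, fixed by the seed
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # iterates only the subset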

FeatureExtractor #

FeatureExtractor(model: Module, last_layer_name: str | None = None, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None)

Bases: Module

Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that it is the same for every forward pass. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.

Parameters:

  • model (Module) –

    PyTorch model

  • last_layer_name (str, default: None ) –

    if the name of the last layer is already known, otherwise it will be determined automatically.

  • enable_backprop (bool, default: False ) –

    whether to enable backprop through the feature extractor to get the gradients of the inputs. Useful, e.g., for Bayesian optimization.

  • feature_reduction (FeatureReduction | str | None, default: None ) –

    when the last-layer features are a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, the available options are {'pick_first', 'pick_last', 'average'}.

Source code in laplace/utils/feature_extractor.py
def __init__(
    self,
    model: nn.Module,
    last_layer_name: str | None = None,
    enable_backprop: bool = False,
    feature_reduction: FeatureReduction | str | None = None,
) -> None:
    if feature_reduction is not None and feature_reduction not in [
        fr.value for fr in FeatureReduction
    ]:
        raise ValueError(
            "`feature_reduction` must take value in the `FeatureReduction enum` or "
            "one of `{'pick_first', 'pick_last', 'average'}`!"
        )

    super().__init__()
    self.model: nn.Module = model
    self._features: dict[str, torch.Tensor] = dict()
    self.enable_backprop: bool = enable_backprop
    self.feature_reduction: FeatureReduction | None = feature_reduction

    self.last_layer: nn.Module | None
    if last_layer_name is None:
        self.last_layer = None
    else:
        self.set_last_layer(last_layer_name)
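
A minimal usage sketch, assuming FeatureExtractor is exported from laplace.utils; the toy model is for illustration only.

import torch
import torch.nn as nn
from laplace.utils import FeatureExtractor

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 3))
feature_extractor = FeatureExtractor(model)  # last layer determined on the first forward pass

x = torch.randn(8, 10)
out, features = feature_extractor.forward_with_features(x)
print(out.shape, features.shape)  # expected: torch.Size([8, 3]) torch.Size([8, 50])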

forward #

forward(x: Tensor | MutableMapping[str, Tensor | Any]) -> Tensor

Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters:

  • x (torch.Tensor or a dict-like object containing the input tensors) –

    one batch of data to use as input for the forward pass

Source code in laplace/utils/feature_extractor.py
def forward(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]
) -> torch.Tensor:
    """Forward pass. If the last layer is not known yet, it will be
    determined when this function is called for the first time.

    Parameters
    ----------
    x : torch.Tensor or a dict-like object containing the input tensors
        one batch of data to use as input for the forward pass
    """
    if self.last_layer is None:
        # if this is the first forward pass and last layer is unknown
        out = self.find_last_layer(x)
    else:
        # if last and penultimate layers are already known
        out = self.model(x)
    return out

forward_with_features #

forward_with_features(x: Tensor | MutableMapping[str, Tensor | Any]) -> tuple[Tensor, Tensor]

Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time.

Parameters:

  • x (torch.Tensor or a dict-like object containing the input tensors) –

    one batch of data to use as input for the forward pass

Source code in laplace/utils/feature_extractor.py
def forward_with_features(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass which returns the output of the penultimate layer along
    with the output of the last layer. If the last layer is not known yet,
    it will be determined when this function is called for the first time.

    Parameters
    ----------
    x : torch.Tensor or a dict-like object containing the input tensors
        one batch of data to use as input for the forward pass
    """
    out = self.forward(x)
    features = self._features[self._last_layer_name]

    if features.dim() > 2 and self.feature_reduction is not None:
        n_intermediate_dims = len(features.shape) - 2

        if self.feature_reduction == FeatureReduction.PICK_FIRST:
            features = features[
                (slice(None), *([0] * n_intermediate_dims), slice(None))
            ].squeeze()
        elif self.feature_reduction == FeatureReduction.PICK_LAST:
            features = features[
                (slice(None), *([-1] * n_intermediate_dims), slice(None))
            ].squeeze()
        else:
            ndim = features.ndim
            features = features.mean(
                dim=tuple(d for d in range(ndim) if d not in [0, ndim - 1])
            ).squeeze()

    return out, features

set_last_layer #

set_last_layer(last_layer_name: str) -> None

Set the last layer of the model by its name. This sets the forward hook to get the output of the penultimate layer.

Parameters:

  • last_layer_name (str) –

    the name of the last layer (fixed in model.named_modules()).

Source code in laplace/utils/feature_extractor.py
def set_last_layer(self, last_layer_name: str) -> None:
    """Set the last layer of the model by its name. This sets the forward
    hook to get the output of the penultimate layer.

    Parameters
    ----------
    last_layer_name : str
        the name of the last layer (fixed in `model.named_modules()`).
    """
    # set last_layer attributes and check if it is linear
    self._last_layer_name = last_layer_name
    self.last_layer = dict(self.model.named_modules())[last_layer_name]
    if not isinstance(self.last_layer, nn.Linear):
        raise ValueError("Use model with a linear last layer.")

    # set forward hook to extract features in future forward passes
    self.last_layer.register_forward_hook(self._get_hook(last_layer_name))

find_last_layer #

find_last_layer(x: Tensor | MutableMapping[str, Tensor | Any]) -> Tensor

Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of torch.nn.Linear. Might not work with every architecture, but is tested with all PyTorch torchvision classification models (besides SqueezeNet, which has no linear last layer).

Parameters:

  • x (torch.Tensor or dict-like object containing the input tensors) –

    one batch of data to use as input for the forward pass

Source code in laplace/utils/feature_extractor.py
def find_last_layer(
    self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]
) -> torch.Tensor:
    """Automatically determines the last layer of the model with one
    forward pass. It assumes that the last layer is the same for every
    forward pass and that it is an instance of `torch.nn.Linear`.
    Might not work with every architecture, but is tested with all PyTorch
    torchvision classification models (besides SqueezeNet, which has no
    linear last layer).

    Parameters
    ----------
    x : torch.Tensor or dict-like object containing the input tensors
        one batch of data to use as input for the forward pass
    """
    if self.last_layer is not None:
        raise ValueError("Last layer is already known.")

    act_out = dict()

    def get_act_hook(name):
        def act_hook(_, input, __):
            # only accepts one input (expects linear layer)
            try:
                act_out[name] = input[0].detach()
            except (IndexError, AttributeError):
                act_out[name] = None
            # remove hook
            handles[name].remove()

        return act_hook

    # set hooks for all modules
    handles = dict()
    for name, module in self.model.named_modules():
        handles[name] = module.register_forward_hook(get_act_hook(name))

    # check if model has more than one module
    # (there might be pathological exceptions)
    if len(handles) <= 2:
        raise ValueError("The model only has one module.")

    # forward pass to find execution order
    out = self.model(x)

    # find the last layer, store features, return output of forward pass
    keys = list(act_out.keys())
    for key in reversed(keys):
        layer = dict(self.model.named_modules())[key]
        if len(list(layer.children())) == 0:
            self.set_last_layer(key)

            # save features from first forward pass
            self._features[key] = act_out[key]

            return out

    raise ValueError("Something went wrong (all modules have children).")

Kron #

Kron(kfacs: list[tuple[Tensor] | Tensor])

Kronecker factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors \(Q\) and \(H\); a single matrix is just a full block Hessian approximation.

Parameters:

  • kfacs (list[Iterable[Tensor] | Tensor]) –

    each element in the list is a tuple of two Kronecker factors Q, H, or a single matrix approximating the Hessian (e.g., in the case of a bias)

Source code in laplace/utils/matrix.py
def __init__(self, kfacs: list[tuple[torch.Tensor] | torch.Tensor]) -> None:
    self.kfacs: list[tuple[torch.Tensor] | torch.Tensor] = kfacs

init_from_model #

init_from_model(model: Module | Iterable[Parameter], device: device) -> Kron

Initialize Kronecker factors based on a model's architecture.

Parameters:

  • model (nn.Module or iterable of parameters, e.g. model.parameters()) –
  • device (device) –

Returns:

  • kron ( Kron ) –

Source code in laplace/utils/matrix.py
@classmethod
def init_from_model(
    cls, model: nn.Module | Iterable[nn.Parameter], device: torch.device
) -> Kron:
    """Initialize Kronecker factors based on a models architecture.

    Parameters
    ----------
    model : nn.Module or iterable of parameters, e.g. model.parameters()
    device : torch.device

    Returns
    -------
    kron : Kron
    """
    if isinstance(model, torch.nn.Module):
        params = model.parameters()
    else:
        params = model

    kfacs = list()
    for p in params:
        if p.ndim == 1:  # bias
            P = p.size(0)
            kfacs.append([torch.zeros(P, P, device=device)])
        elif 4 >= p.ndim >= 2:  # fully connected or conv
            if p.ndim == 2:  # fully connected
                P_in, P_out = p.size()
            else:
                P_in, P_out = p.shape[0], np.prod(p.shape[1:])

            kfacs.append(
                [
                    torch.zeros(P_in, P_in, device=device),
                    torch.zeros(P_out, P_out, device=device),
                ]
            )
        else:
            raise ValueError("Invalid parameter shape in network.")
    return cls(kfacs)
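
A small sketch, assuming Kron is exported from laplace.utils (the source lives in laplace/utils/matrix.py); the toy model is illustrative only.

import torch
import torch.nn as nn
from laplace.utils import Kron

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
kron_curv = Kron.init_from_model(model, torch.device("cpu"))

# One group per parameter tensor: two square zero factors for each weight
# matrix, a single square zero block for each bias vector.
for factors in kron_curv.kfacs:
    print([tuple(f.shape) for f in factors])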

__add__ #

__add__(other: Kron) -> Kron

Add up Kronecker factors self and other.

Parameters:

  • other (Kron) –

Returns:

  • kron ( Kron ) –

Source code in laplace/utils/matrix.py
def __add__(self, other: Kron) -> Kron:
    """Add up Kronecker factors `self` and `other`.

    Parameters
    ----------
    other : Kron

    Returns
    -------
    kron : Kron
    """
    if not isinstance(other, Kron):
        raise ValueError("Can only add Kron to Kron.")

    kfacs = [
        [Hi.add(Hj) for Hi, Hj in zip(Fi, Fj)]
        for Fi, Fj in zip(self.kfacs, other.kfacs)
    ]

    return Kron(kfacs)

__mul__ #

__mul__(scalar: float | Tensor) -> Kron

Multiply all Kronecker factors by scalar. The multiplication is distributed across the number of factors using pow(scalar, 1 / len(F)). len(F) is either 1 or 2.

Parameters:

  • scalar (float | Tensor) –

Returns:

  • kron ( Kron ) –

Source code in laplace/utils/matrix.py
def __mul__(self, scalar: float | torch.Tensor) -> Kron:
    """Multiply all Kronecker factors by scalar.
    The multiplication is distributed across the number of factors
    using `pow(scalar, 1 / len(F))`. `len(F)` is either `1` or `2`.

    Parameters
    ----------
    scalar : float, torch.Tensor

    Returns
    -------
    kron : Kron
    """
    if not _is_valid_scalar(scalar):
        raise ValueError("Input not valid python or torch scalar.")

    # distribute factors evenly so that each group is multiplied by factor
    kfacs = [[pow(scalar, 1 / len(F)) * Hi for Hi in F] for F in self.kfacs]
    return Kron(kfacs)
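
Since \(\mathrm{kron}(aQ, aH) = a^2 \, \mathrm{kron}(Q, H)\), distributing the scalar as pow(scalar, 1 / len(F)) scales the dense block by exactly scalar. A small check sketch, using to_matrix (testing-only, see below) and assuming Kron is exported from laplace.utils.

import torch
from laplace.utils import Kron

Q = 2.0 * torch.eye(2)
H = 5.0 * torch.eye(3)
kron_curv = Kron([[Q, H]])

scaled = kron_curv * 4.0  # each factor is multiplied by 4 ** (1 / 2) = 2
assert torch.allclose(scaled.to_matrix(), 4.0 * kron_curv.to_matrix())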

decompose #

decompose(damping: bool = False) -> KronDecomposed

Eigendecompose Kronecker factors and turn into KronDecomposed.

Parameters:

  • damping (bool, default: False ) –

    use damping

Returns:

  • kron_decomposed ( KronDecomposed ) –

Source code in laplace/utils/matrix.py
def decompose(self, damping: bool = False) -> KronDecomposed:
    """Eigendecompose Kronecker factors and turn into `KronDecomposed`.
    Parameters
    ----------
    damping : bool
        use damping

    Returns
    -------
    kron_decomposed : KronDecomposed
    """
    eigvecs, eigvals = list(), list()
    for F in self.kfacs:
        Qs, ls = list(), list()
        for Hi in F:
            if Hi.ndim > 1:
                # Dense Kronecker factor.
                eigval, Q = symeig(Hi)
            else:
                # Diagonal Kronecker factor.
                eigval = Hi
                # This might be too memory intensive since len(Hi) can be large.
                Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device)
            Qs.append(Q)
            ls.append(eigval)
        eigvecs.append(Qs)
        eigvals.append(ls)
    return KronDecomposed(eigvecs, eigvals, damping=damping)

_bmm #

_bmm(W: Tensor) -> Tensor

Implementation of bmm which casts the parameters to the right shape.

Parameters:

  • W (Tensor) –

    matrix (batch, classes, params)

Returns:

  • SW ( Tensor ) –

    result (batch, classes, params)

Source code in laplace/utils/matrix.py
def _bmm(self, W: torch.Tensor) -> torch.Tensor:
    """Implementation of `bmm` which casts the parameters to the right shape.

    Parameters
    ----------
    W : torch.Tensor
        matrix `(batch, classes, params)`

    Returns
    -------
    SW : torch.Tensor
        result `(batch, classes, params)`
    """
    # self @ W[batch, k, params]
    assert len(W.size()) == 3
    B, K, P = W.size()
    W = W.reshape(B * K, P)
    cur_p = 0
    SW = list()
    for Fs in self.kfacs:
        if len(Fs) == 1:
            Q = Fs[0]
            p = len(Q)
            W_p = W[:, cur_p : cur_p + p].T
            SW.append((Q @ W_p).T if Q.ndim > 1 else (Q.view(-1, 1) * W_p).T)
            cur_p += p
        elif len(Fs) == 2:
            Q, H = Fs
            p_in, p_out = len(Q), len(H)
            p = p_in * p_out
            W_p = W[:, cur_p : cur_p + p].reshape(B * K, p_in, p_out)
            QW_p = Q @ W_p if Q.ndim > 1 else Q.view(-1, 1) * W_p
            QW_pHt = QW_p @ H.T if H.ndim > 1 else QW_p * H.view(1, -1)
            SW.append(QW_pHt.reshape(B * K, p_in * p_out))
            cur_p += p
        else:
            raise AttributeError("Shape mismatch")
    SW = torch.cat(SW, dim=1).reshape(B, K, P)
    return SW

bmm #

bmm(W: Tensor, exponent: float = 1) -> Tensor

Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive or a regularization based on Kronecker factors as in continual learning.

Parameters:

  • W (Tensor) –

    matrix (batch, classes, params)

  • exponent (float, default: 1 ) –

    can only be 1 for Kron; other exponent values of the Kronecker factors require KronDecomposed.

Returns:

  • SW ( Tensor ) –

    result (batch, classes, params)

Source code in laplace/utils/matrix.py
def bmm(self, W: torch.Tensor, exponent: float = 1) -> torch.Tensor:
    """Batched matrix multiplication with the Kronecker factors.
    If Kron is `H`, we compute `H @ W`.
    This is useful for computing the predictive or a regularization
    based on Kronecker factors as in continual learning.

    Parameters
    ----------
    W : torch.Tensor
        matrix `(batch, classes, params)`
    exponent: float, default=1
        only can be `1` for Kron, requires `KronDecomposed` for other
        exponent values of the Kronecker factors.

    Returns
    -------
    SW : torch.Tensor
        result `(batch, classes, params)`
    """
    if exponent != 1:
        raise ValueError("Only supported after decomposition.")
    if W.ndim == 1:
        return self._bmm(W.unsqueeze(0).unsqueeze(0)).squeeze()
    elif W.ndim == 2:
        return self._bmm(W.unsqueeze(1)).squeeze()
    elif W.ndim == 3:
        return self._bmm(W)
    else:
        raise ValueError("Invalid shape for W")

logdet #

logdet() -> Tensor

Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.

Returns:

  • logdet ( Tensor ) –
Source code in laplace/utils/matrix.py
def logdet(self) -> torch.Tensor:
    """Compute log determinant of the Kronecker factors and sums them up.
    This corresponds to the log determinant of the entire Hessian approximation.

    Returns
    -------
    logdet : torch.Tensor
    """
    logdet = 0
    for F in self.kfacs:
        if len(F) == 1:
            logdet += F[0].logdet() if F[0].ndim > 1 else F[0].log().sum()
        else:  # len(F) == 2
            Hi, Hj = F
            p_in, p_out = len(Hi), len(Hj)
            logdet += p_out * Hi.logdet() if Hi.ndim > 1 else p_out * Hi.log().sum()
            logdet += p_in * Hj.logdet() if Hj.ndim > 1 else p_in * Hj.log().sum()
    return logdet
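
For a two-factor block with \(H_i \in \mathbb{R}^{p_{in} \times p_{in}}\) and \(H_j \in \mathbb{R}^{p_{out} \times p_{out}}\), the loop applies the standard Kronecker determinant identity

\[ \log\det(H_i \otimes H_j) = p_{out} \log\det H_i + p_{in} \log\det H_j , \]

and sums the result over all blocks.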

diag #

diag() -> Tensor

Extract diagonal of the entire Kronecker factorization.

Returns:

  • diag ( Tensor ) –
Source code in laplace/utils/matrix.py
def diag(self) -> torch.Tensor:
    """Extract diagonal of the entire Kronecker factorization.

    Returns
    -------
    diag : torch.Tensor
    """
    diags = list()
    for F in self.kfacs:
        F0 = F[0].diag() if F[0].ndim > 1 else F[0]
        if len(F) == 1:
            diags.append(F0)
        else:
            F1 = F[1].diag() if F[1].ndim > 1 else F[1]
            diags.append(torch.outer(F0, F1).flatten())
    return torch.cat(diags)

to_matrix #

to_matrix() -> Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Returns:

  • block_diag ( Tensor ) –
Source code in laplace/utils/matrix.py
def to_matrix(self) -> torch.Tensor:
    """Make the Kronecker factorization dense by computing the kronecker product.
    Warning: this should only be used for testing purposes as it will allocate
    large amounts of memory for big architectures.

    Returns
    -------
    block_diag : torch.Tensor
    """
    blocks = list()
    for F in self.kfacs:
        F0 = F[0] if F[0].ndim > 1 else F[0].diag()
        if len(F) == 1:
            blocks.append(F0)
        else:
            F1 = F[1] if F[1].ndim > 1 else F[1].diag()
            blocks.append(kron(F0, F1))
    return block_diag(blocks)
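
A small consistency sketch, assuming Kron is exported from laplace.utils: the diagonal extracted by diag() matches the diagonal of the dense matrix built here.

import torch
from laplace.utils import Kron

Q = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
H = torch.tensor([[1.5, 0.0], [0.0, 0.5]])
kron_curv = Kron([[Q, H]])

assert torch.allclose(kron_curv.diag(), kron_curv.to_matrix().diag())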

KronDecomposed #

KronDecomposed(eigenvectors: list[tuple[Tensor]], eigenvalues: list[tuple[Tensor]], deltas: Tensor | None = None, damping: bool = False)

Decomposed Kronecker factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layer-wise scalars, but we can no longer add another Kron or KronDecomposed.

Parameters:

  • eigenvectors (list[Tuple[Tensor]]) –

    eigenvectors corresponding to matrices in a corresponding Kron

  • eigenvalues (list[Tuple[Tensor]]) –

    eigenvalues corresponding to matrices in a corresponding Kron

  • deltas (Tensor, default: None ) –

    addend for each group of Kronecker factors representing, for example, a prior precision

  • damping (bool, default: False ) –

    use a damped approximation that mixes the prior and the Kronecker factors partially multiplicatively

Source code in laplace/utils/matrix.py
def __init__(
    self,
    eigenvectors: list[tuple[torch.Tensor]],
    eigenvalues: list[tuple[torch.Tensor]],
    deltas: torch.Tensor | None = None,
    damping: bool = False,
):
    self.eigenvectors: list[tuple[torch.Tensor]] = eigenvectors
    self.eigenvalues: list[tuple[torch.Tensor]] = eigenvalues
    device: torch.device = eigenvectors[0][0].device
    if deltas is None:
        self.deltas: torch.Tensor = torch.zeros(len(self), device=device)
    else:
        self._check_deltas(deltas)
        self.deltas: torch.Tensor = deltas
    self.damping: bool = damping

__add__ #

__add__(deltas: Tensor) -> KronDecomposed

Add a per-layer scalar or a single scalar to the Kronecker factors.

Parameters:

  • deltas (Tensor) –

    either same length as eigenvalues or scalar.

Returns:

  • kron ( KronDecomposed ) –

Source code in laplace/utils/matrix.py
def __add__(self, deltas: torch.Tensor) -> KronDecomposed:
    """Add scalar per layer or only scalar to Kronecker factors.

    Parameters
    ----------
    deltas : torch.Tensor
        either same length as `eigenvalues` or scalar.

    Returns
    -------
    kron : KronDecomposed
    """
    self._check_deltas(deltas)
    return KronDecomposed(self.eigenvectors, self.eigenvalues, self.deltas + deltas)

__mul__ #

__mul__(scalar: Tensor | float) -> KronDecomposed

Multiply by a scalar by changing the eigenvalues. Same as for the case of Kron.

Parameters:

  • scalar (Tensor or float) –

Returns:

  • kron ( KronDecomposed ) –

Source code in laplace/utils/matrix.py
def __mul__(self, scalar: torch.Tensor | float) -> KronDecomposed:
    """Multiply by a scalar by changing the eigenvalues.
    Same as for the case of `Kron`.

    Parameters
    ----------
    scalar : torch.Tensor or float

    Returns
    -------
    kron : KronDecomposed
    """
    if not _is_valid_scalar(scalar):
        raise ValueError("Invalid argument, can only multiply Kron with scalar.")

    eigenvalues = [
        [pow(scalar, 1 / len(ls)) * eigval for eigval in ls]
        for ls in self.eigenvalues
    ]
    return KronDecomposed(self.eigenvectors, eigenvalues, self.deltas)

logdet #

logdet() -> Tensor

Compute the log determinants of the Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are added.

Returns:

  • logdet ( Tensor ) –
Source code in laplace/utils/matrix.py
def logdet(self) -> torch.Tensor:
    """Compute log determinant of the Kronecker factors and sums them up.
    This corresponds to the log determinant of the entire Hessian approximation.
    In contrast to `Kron.logdet()`, additive `deltas` corresponding to prior
    precisions are added.

    Returns
    -------
    logdet : torch.Tensor
    """
    logdet = 0
    for ls, delta in zip(self.eigenvalues, self.deltas):
        if len(ls) == 1:  # not KFAC just full
            logdet += torch.log(ls[0] + delta).sum()
        elif len(ls) == 2:
            l1, l2 = ls
            if self.damping:
                l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
                logdet += torch.log(torch.outer(l1d, l2d)).sum()
            else:
                logdet += torch.log(torch.outer(l1, l2) + delta).sum()
        else:
            raise ValueError("Too many Kronecker factors. Something went wrong.")
    return logdet

_bmm #

_bmm(W: Tensor, exponent: float = -1) -> Tensor

Implementation of bmm, i.e., self ** exponent @ W.

Parameters:

  • W (Tensor) –

    matrix (batch, classes, params)

  • exponent (float, default: -1 ) –

    exponent on self

Returns:

  • SW ( Tensor ) –

    result (batch, classes, params)

Source code in laplace/utils/matrix.py
def _bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
    """Implementation of `bmm`, i.e., `self ** exponent @ W`.

    Parameters
    ----------
    W : torch.Tensor
        matrix `(batch, classes, params)`
    exponent : float
        exponent on `self`

    Returns
    -------
    SW : torch.Tensor
        result `(batch, classes, params)`
    """
    # self @ W[batch, k, params]
    assert len(W.size()) == 3
    B, K, P = W.size()
    W = W.reshape(B * K, P)
    cur_p = 0
    SW = list()
    for i, (ls, Qs, delta) in enumerate(
        zip(self.eigenvalues, self.eigenvectors, self.deltas)
    ):
        if len(ls) == 1:
            Q, eigval, p = Qs[0], ls[0], len(ls[0])
            ldelta_exp = torch.pow(eigval + delta, exponent).reshape(-1, 1)
            W_p = W[:, cur_p : cur_p + p].T
            SW.append((Q @ (ldelta_exp * (Q.T @ W_p))).T)
            cur_p += p
        elif len(ls) == 2:
            Q1, Q2 = Qs
            l1, l2 = ls
            p = len(l1) * len(l2)
            if self.damping:
                l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
                ldelta_exp = torch.pow(torch.outer(l1d, l2d), exponent).unsqueeze(0)
            else:
                ldelta_exp = torch.pow(
                    torch.outer(l1, l2) + delta, exponent
                ).unsqueeze(0)
            p_in, p_out = len(l1), len(l2)
            W_p = W[:, cur_p : cur_p + p].reshape(B * K, p_in, p_out)
            W_p = (Q1.T @ W_p @ Q2) * ldelta_exp
            W_p = Q1 @ W_p @ Q2.T
            SW.append(W_p.reshape(B * K, p_in * p_out))
            cur_p += p
        else:
            raise AttributeError("Shape mismatch")
    SW = torch.cat(SW, dim=1).reshape(B, K, P)
    return SW

bmm #

bmm(W: Tensor, exponent: float = -1) -> Tensor

Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. Compared to Kron.bmm, a prior can be added here in the form of deltas, and the exponent can be other than 1. Computes \(H^{exponent} W\).

Parameters:

  • W (Tensor) –

    matrix (batch, classes, params)

  • exponent (float, default: -1 ) –

Returns:

  • SW ( Tensor ) –

    result (batch, classes, params)

Source code in laplace/utils/matrix.py
def bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
    """Batched matrix multiplication with the decomposed Kronecker factors.
    This is useful for computing the predictive or a regularization loss.
    Compared to `Kron.bmm`, a prior can be added here in form of `deltas`
    and the exponent can be other than just 1.
    Computes \\(H^{exponent} W\\).

    Parameters
    ----------
    W : torch.Tensor
        matrix `(batch, classes, params)`
    exponent: float, default=-1

    Returns
    -------
    SW : torch.Tensor
        result `(batch, classes, params)`
    """
    if W.ndim == 1:
        return self._bmm(W.unsqueeze(0).unsqueeze(0), exponent).squeeze()
    elif W.ndim == 2:
        return self._bmm(W.unsqueeze(1), exponent).squeeze()
    elif W.ndim == 3:
        return self._bmm(W, exponent)
    else:
        raise ValueError("Invalid shape for W")

diag #

diag(exponent: float = 1) -> Tensor

Extract diagonal of the entire decomposed Kronecker factorization.

Parameters:

  • exponent (float, default: 1 ) –

    exponent of the Kronecker factorization

Returns:

  • diag ( Tensor ) –
Source code in laplace/utils/matrix.py
def diag(self, exponent: float = 1) -> torch.Tensor:
    """Extract diagonal of the entire decomposed Kronecker factorization.

    Parameters
    ----------
    exponent: float, default=1
        exponent of the Kronecker factorization

    Returns
    -------
    diag : torch.Tensor
    """
    diags = list()
    for Qs, ls, delta in zip(self.eigenvectors, self.eigenvalues, self.deltas):
        if len(ls) == 1:
            Ql = Qs[0] * torch.pow(ls[0] + delta, exponent).reshape(1, -1)
            d = torch.einsum(
                "mp,mp->m", Ql, Qs[0]
            )  # only compute inner products for diag
            diags.append(d)
        else:
            Q1, Q2 = Qs
            l1, l2 = ls
            if self.damping:
                delta_sqrt = torch.sqrt(delta)
                eigval = torch.pow(
                    torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent
                )
            else:
                eigval = torch.pow(torch.outer(l1, l2) + delta, exponent)
            d = oe.contract("mp,nq,pq,mp,nq->mn", Q1, Q2, eigval, Q1, Q2).flatten()
            diags.append(d)
    return torch.cat(diags)

to_matrix #

to_matrix(exponent: float = 1) -> Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Parameters:

  • exponent (float, default: 1 ) –

    exponent of the Kronecker factorization

Returns:

  • block_diag ( Tensor ) –
Source code in laplace/utils/matrix.py
def to_matrix(self, exponent: float = 1) -> torch.Tensor:
    """Make the Kronecker factorization dense by computing the kronecker product.
    Warning: this should only be used for testing purposes as it will allocate
    large amounts of memory for big architectures.

    Parameters
    ----------
    exponent: float, default=1
        exponent of the Kronecker factorization

    Returns
    -------
    block_diag : torch.Tensor
    """
    blocks = list()
    for Qs, ls, delta in zip(self.eigenvectors, self.eigenvalues, self.deltas):
        if len(ls) == 1:
            Q, eigval = Qs[0], ls[0]
            blocks.append(Q @ torch.diag(torch.pow(eigval + delta, exponent)) @ Q.T)
        else:
            Q1, Q2 = Qs
            l1, l2 = ls
            Q = kron(Q1, Q2)
            if self.damping:
                delta_sqrt = torch.sqrt(delta)
                eigval = torch.pow(
                    torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent
                )
            else:
                eigval = torch.pow(torch.outer(l1, l2) + delta, exponent)
            L = torch.diag(eigval.flatten())
            blocks.append(Q @ L @ Q.T)
    return block_diag(blocks)

SubnetMask #

SubnetMask(model: Module)

Base class for all subnetwork masks in this library (for subnetwork Laplace).

Parameters:

  • model (Module) –
Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module) -> None:
    self.model: nn.Module = model
    self.parameter_vector: torch.Tensor = parameters_to_vector(
        self.model.parameters()
    ).detach()
    self._n_params: int = len(self.parameter_vector)
    self._indices: torch.LongTensor | None = None
    self._n_params_subnet: int | None = None

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices
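
A minimal sketch, assuming SubnetMask is exported from laplace.utils; the base class is instantiated here only to illustrate the mask-to-indices conversion.

import torch
import torch.nn as nn
from laplace.utils import SubnetMask

model = nn.Linear(3, 2)  # 3 * 2 weights + 2 biases = 8 parameters
subnet_mask_obj = SubnetMask(model)

n_params = sum(p.numel() for p in model.parameters())  # 8
mask = torch.zeros(n_params, dtype=torch.bool)
mask[[0, 3, 7]] = True  # mark three parameters as the subnetwork
print(subnet_mask_obj.convert_subnet_mask_to_indices(mask))  # tensor([0, 3, 7])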

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader: DataLoader) -> Tensor

Get the subnetwork mask.

Parameters:

  • train_loader (DataLoader) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask ( Tensor ) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor:
    """Get the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask: torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
    """
    raise NotImplementedError

RandomSubnetMask #

RandomSubnetMask(model: Module, n_params_subnet: int)

Bases: ScoreBasedSubnetMask

Subnetwork mask of parameters sampled uniformly at random.

Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module, n_params_subnet: int) -> None:
    super().__init__(model)

    if n_params_subnet is None:
        raise ValueError(
            "Need to pass number of subnetwork parameters when using subnetwork Laplace."
        )
    if n_params_subnet > self._n_params:
        raise ValueError(
            f"Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params})."
        )
    self._n_params_subnet = n_params_subnet
    self._param_scores: torch.Tensor | None = None
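
A minimal sketch, assuming RandomSubnetMask is exported from laplace.utils; the tiny loader is only there to satisfy the select signature, since random scores do not depend on the data.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.utils import RandomSubnetMask

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
subnet_mask = RandomSubnetMask(model, n_params_subnet=50)

loader = DataLoader(
    TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=8
)
indices = subnet_mask.select(loader)
print(indices.shape)  # torch.Size([50]) -- 50 randomly chosen parameter indices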

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader)

Get the subnetwork mask by ranking parameters in descending order of their scores.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader):
    """Get the subnetwork mask by ranking parameters in descending order of their scores."""
    if self._param_scores is None:
        self._param_scores = self.compute_param_scores(train_loader)
    self._check_param_scores()

    idx = torch.argsort(self._param_scores, descending=True)[
        : self._n_params_subnet
    ]
    idx = idx.sort()[0]
    subnet_mask = torch.zeros_like(self.parameter_vector).bool()
    subnet_mask[idx] = 1
    return subnet_mask

LargestMagnitudeSubnetMask #

LargestMagnitudeSubnetMask(model: Module, n_params_subnet: int)

Bases: ScoreBasedSubnetMask

Subnetwork mask identifying the parameters with the largest magnitude.

Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module, n_params_subnet: int) -> None:
    super().__init__(model)

    if n_params_subnet is None:
        raise ValueError(
            "Need to pass number of subnetwork parameters when using subnetwork Laplace."
        )
    if n_params_subnet > self._n_params:
        raise ValueError(
            f"Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params})."
        )
    self._n_params_subnet = n_params_subnet
    self._param_scores: torch.Tensor | None = None

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader)

Get the subnetwork mask by ranking parameters in descending order of their scores.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader):
    """Get the subnetwork mask by ranking parameters in descending order of their scores."""
    if self._param_scores is None:
        self._param_scores = self.compute_param_scores(train_loader)
    self._check_param_scores()

    idx = torch.argsort(self._param_scores, descending=True)[
        : self._n_params_subnet
    ]
    idx = idx.sort()[0]
    subnet_mask = torch.zeros_like(self.parameter_vector).bool()
    subnet_mask[idx] = 1
    return subnet_mask

LargestVarianceDiagLaplaceSubnetMask #

LargestVarianceDiagLaplaceSubnetMask(model: Module, n_params_subnet: int, diag_laplace_model: DiagLaplace)

Bases: ScoreBasedSubnetMask

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using a diagonal Laplace approximation over all model parameters).

Parameters:

  • model (Module) –
  • n_params_subnet (int) –

    number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)

  • diag_laplace_model (`laplace.baselaplace.DiagLaplace`) –

    diagonal Laplace model to use for variance estimation

Source code in laplace/utils/subnetmask.py
def __init__(
    self,
    model: nn.Module,
    n_params_subnet: int,
    diag_laplace_model: laplace.baselaplace.DiagLaplace,
):
    super().__init__(model, n_params_subnet)
    self.diag_laplace_model: laplace.baselaplace.DiagLaplace = diag_laplace_model

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader)

Get the subnetwork mask by ranking parameters in descending order of their scores.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader):
    """Get the subnetwork mask by ranking parameters in descending order of their scores."""
    if self._param_scores is None:
        self._param_scores = self.compute_param_scores(train_loader)
    self._check_param_scores()

    idx = torch.argsort(self._param_scores, descending=True)[
        : self._n_params_subnet
    ]
    idx = idx.sort()[0]
    subnet_mask = torch.zeros_like(self.parameter_vector).bool()
    subnet_mask[idx] = 1
    return subnet_mask

LargestVarianceSWAGSubnetMask #

LargestVarianceSWAGSubnetMask(model: Module, n_params_subnet: int, likelihood: Likelihood | str = Likelihood.CLASSIFICATION, swag_n_snapshots: int = 40, swag_snapshot_freq: int = 1, swag_lr: float = 0.01)

Bases: ScoreBasedSubnetMask

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using diagonal SWAG over all model parameters).

Parameters:

  • model (Module) –
  • n_params_subnet (int) –

    number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)

  • likelihood (str, default: CLASSIFICATION ) –

    'classification' or 'regression'

  • swag_n_snapshots (int, default: 40 ) –

    number of model snapshots to collect for SWAG

  • swag_snapshot_freq (int, default: 1 ) –

    SWAG snapshot collection frequency (in epochs)

  • swag_lr (float, default: 0.01 ) –

    learning rate for SWAG snapshot collection

Source code in laplace/utils/subnetmask.py
def __init__(
    self,
    model: nn.Module,
    n_params_subnet: int,
    likelihood: Likelihood | str = Likelihood.CLASSIFICATION,
    swag_n_snapshots: int = 40,
    swag_snapshot_freq: int = 1,
    swag_lr: float = 0.01,
):
    if likelihood not in [Likelihood.CLASSIFICATION, Likelihood.REGRESSION]:
        raise ValueError("Only available for classification and regression!")

    super().__init__(model, n_params_subnet)

    self.likelihood: Likelihood | str = likelihood
    self.swag_n_snapshots: int = swag_n_snapshots
    self.swag_snapshot_freq: int = swag_snapshot_freq
    self.swag_lr: float = swag_lr

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.data.utils.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader)

Get the subnetwork mask by ranking parameters in descending order of their scores.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader):
    """Get the subnetwork mask by (descendingly) ranking parameters based on their scores."""
    if self._param_scores is None:
        self._param_scores = self.compute_param_scores(train_loader)
    self._check_param_scores()

    idx = torch.argsort(self._param_scores, descending=True)[
        : self._n_params_subnet
    ]
    idx = idx.sort()[0]
    subnet_mask = torch.zeros_like(self.parameter_vector).bool()
    subnet_mask[idx] = 1
    return subnet_mask

ParamNameSubnetMask #

ParamNameSubnetMask(model: Module, parameter_names: list[str])

Bases: SubnetMask

Subnetwork mask corresponding to the specified parameters of the neural network.

Parameters:

  • model (Module) –
  • parameter_names (list[str]) –

    list of names of the parameters (as in model.named_parameters()) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module, parameter_names: list[str]) -> None:
    super().__init__(model)
    self._parameter_names: list[str] = parameter_names
    self._n_params_subnet: int | None = None

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader: DataLoader) -> Tensor

Get the subnetwork mask identifying the specified parameters.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor:
    """Get the subnetwork mask identifying the specified parameters."""

    self._check_param_names()

    subnet_mask_list = []
    for name, param in self.model.named_parameters():
        if name in self._parameter_names:
            mask_method = torch.ones_like
        else:
            mask_method = torch.zeros_like
        subnet_mask_list.append(mask_method(parameters_to_vector(param)))
    subnet_mask = torch.cat(subnet_mask_list).bool()
    return subnet_mask
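
Example (not part of the library docstrings; the toy model and parameter names below are illustrative): selecting only the final linear layer's parameters by name. Name-based masks do not need a train_loader.

import torch.nn as nn
from laplace.utils.subnetmask import ParamNameSubnetMask

# toy model; its parameter names are '0.weight', '0.bias', '2.weight', '2.bias'
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

subnet_mask = ParamNameSubnetMask(model, parameter_names=["2.weight", "2.bias"])
subnet_indices = subnet_mask.select()  # get_subnet_mask ignores the train_loader here

print(subnet_indices.shape)  # torch.Size([18]): 8 * 2 weights + 2 biases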

ModuleNameSubnetMask #

ModuleNameSubnetMask(model: Module, module_names: list[str])

Bases: SubnetMask

Subnetwork mask corresponding to the specified modules of the neural network.

Parameters:

  • model (Module) –
  • module_names (list[str]) –

    list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules

Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module, module_names: list[str]):
    super().__init__(model)
    self._module_names: list[str] = module_names
    self._n_params_subnet: int | None = None

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader: DataLoader) -> Tensor

Get the subnetwork mask identifying the specified modules.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor:
    """Get the subnetwork mask identifying the specified modules."""

    self._check_module_names()

    subnet_mask_list = []
    for name, module in self.model.named_modules():
        if len(list(module.children())) > 0 or len(list(module.parameters())) == 0:
            continue
        if name in self._module_names:
            mask_method = torch.ones_like
        else:
            mask_method = torch.zeros_like
        subnet_mask_list.append(
            mask_method(parameters_to_vector(module.parameters()))
        )
    subnet_mask = torch.cat(subnet_mask_list).bool()
    return subnet_mask
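
Example (illustrative, not from the library docstrings): masking whole leaf modules instead of individual parameter tensors.

import torch.nn as nn
from laplace.utils.subnetmask import ModuleNameSubnetMask

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# treat only the first linear layer (module name '0') as the subnetwork
subnet_mask = ModuleNameSubnetMask(model, module_names=["0"])
subnet_indices = subnet_mask.select()

print(subnet_indices.shape)  # torch.Size([40]): 4 * 8 weights + 8 biases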

LastLayerSubnetMask #

LastLayerSubnetMask(model: Module, last_layer_name: str | None = None)

Bases: ModuleNameSubnetMask

Subnetwork mask corresponding to the last layer of the neural network.

Parameters:

  • model (Module) –
  • last_layer_name (str | None, default: None ) –

    name of the model's last layer, if None it will be determined automatically

Source code in laplace/utils/subnetmask.py
def __init__(self, model: nn.Module, last_layer_name: str | None = None):
    super().__init__(model, [])
    self._feature_extractor: FeatureExtractor = FeatureExtractor(
        self.model, last_layer_name=last_layer_name
    )
    self._n_params_subnet: int | None = None

convert_subnet_mask_to_indices #

convert_subnet_mask_to_indices(subnet_mask: Tensor) -> LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters:

  • subnet_mask (Tensor) –

    a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def convert_subnet_mask_to_indices(
    self, subnet_mask: torch.Tensor
) -> torch.LongTensor:
    """Converts a subnetwork mask into subnetwork indices.

    Parameters
    ----------
    subnet_mask : torch.Tensor
        a binary vector of size (n_params) where 1s locate the subnetwork parameters
        within the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if not isinstance(subnet_mask, torch.Tensor):
        raise ValueError("Subnetwork mask needs to be torch.Tensor!")
    elif (
        subnet_mask.dtype
        not in [
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ]
        or len(subnet_mask.shape) != 1
    ):
        raise ValueError(
            "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
        )
    elif (
        len(subnet_mask) != self._n_params
        or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1])
        != self._n_params
    ):
        raise ValueError(
            "Subnetwork mask needs to be a binary vector of"
            "size (n_params) where 1s locate the subnetwork"
            "parameters within the vectorized model parameters"
            "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
        )

    subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
    return subnet_mask_indices

select #

select(train_loader: DataLoader | None = None) -> LongTensor

Select the subnetwork mask.

Parameters:

  • train_loader (DataLoader, default: None ) –

    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access \(N\), size of the data set

Returns:

  • subnet_mask_indices ( LongTensor ) –

    a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

Source code in laplace/utils/subnetmask.py
def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
    """Select the subnetwork mask.

    Parameters
    ----------
    train_loader : torch.utils.data.DataLoader, default=None
        each iterate is a training batch (X, y);
        `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

    Returns
    -------
    subnet_mask_indices : torch.LongTensor
        a vector of indices of the vectorized model parameters
        (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        that define the subnetwork
    """
    if self._indices is not None:
        raise ValueError("Subnetwork mask already selected.")

    subnet_mask = self.get_subnet_mask(train_loader)
    self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
    return self._indices

get_subnet_mask #

get_subnet_mask(train_loader: DataLoader) -> Tensor

Get the subnetwork mask identifying the last layer.

Source code in laplace/utils/subnetmask.py
def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor:
    """Get the subnetwork mask identifying the last layer."""
    if train_loader is None:
        raise ValueError("Need to pass train loader for subnet selection.")

    self._feature_extractor.eval()
    if self._feature_extractor.last_layer is None:
        X = next(iter(train_loader))[0]
        with torch.no_grad():
            self._feature_extractor.find_last_layer(X[:1].to(self._device))
    self._module_names = [self._feature_extractor._last_layer_name]

    return super().get_subnet_mask(train_loader)
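
Example (illustrative): the last layer is detected automatically from one batch of the train loader, so a loader must be passed to select. The toy model and data are placeholders.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.utils.subnetmask import LastLayerSubnetMask

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
X, y = torch.randn(32, 4), torch.randint(0, 2, (32,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=8)

subnet_mask = LastLayerSubnetMask(model)  # last_layer_name is inferred automatically
subnet_indices = subnet_mask.select(train_loader)

print(subnet_indices.shape)  # torch.Size([18]): 8 * 2 weights + 2 biases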

RunningNLLMetric #

RunningNLLMetric(ignore_index: int = -100)

Bases: Metric

Running negative log-likelihood (NLL) metric. It accumulates the summed NLL and the number of valid (non-ignored) labels across update calls, so the average NLL over a whole data loader can be computed at the end.

Parameters:

  • ignore_index (int, default: -100 ) –

    which class label to ignore when computing the NLL loss

Source code in laplace/utils/metrics.py
def __init__(self, ignore_index: int = -100) -> None:
    super().__init__()
    self.add_state("nll_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state(
        "n_valid_labels", default=torch.tensor(0.0), dist_reduce_fx="sum"
    )
    self.ignore_index: int = ignore_index

update #

update(probs: Tensor, targets: Tensor) -> None

Parameters:

  • probs (Tensor) –

    probability tensor of shape (..., n_classes)

  • targets (Tensor) –

    integer tensor of shape (...)

Source code in laplace/utils/metrics.py
def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None:
    """
    Parameters
    ----------
    probs: torch.Tensor
        probability tensor of shape (..., n_classes)

    targets: torch.Tensor
        integer tensor of shape (...)
    """
    probs = probs.view(-1, probs.shape[-1])
    targets = targets.view(-1)

    self.nll_sum += F.nll_loss(
        probs.log(), targets, ignore_index=self.ignore_index, reduction="sum"
    )
    self.n_valid_labels += (targets != self.ignore_index).sum()
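
Example (illustrative). It assumes the metric's compute() method, which is part of the class but not reproduced above, returns nll_sum / n_valid_labels, i.e. the average NLL over the non-ignored labels seen so far.

import torch
import torch.nn.functional as F
from laplace.utils.metrics import RunningNLLMetric

metric = RunningNLLMetric(ignore_index=-100)

for _ in range(3):  # e.g. three validation batches
    probs = F.softmax(torch.randn(8, 5), dim=-1)  # (batch_size, n_classes), probabilities
    targets = torch.randint(0, 5, (8,))
    metric.update(probs, targets)

avg_nll = metric.compute()  # assumed: nll_sum / n_valid_labels
print(avg_nll)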

get_nll #

get_nll(out_dist: Tensor, targets: Tensor) -> Tensor
Source code in laplace/utils/utils.py
def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return F.nll_loss(torch.log(out_dist), targets)

validate #

validate(laplace: BaseLaplace, val_loader: DataLoader, loss: Metric | Callable[[Tensor, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor], Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels') -> float
Source code in laplace/utils/utils.py
@torch.no_grad()
def validate(
    laplace: laplace.baselaplace.BaseLaplace,
    val_loader: DataLoader,
    loss: torchmetrics.Metric
    | Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    pred_type: PredType | str = PredType.GLM,
    link_approx: LinkApprox | str = LinkApprox.PROBIT,
    n_samples: int = 100,
    dict_key_y: str = "labels",
) -> float:
    laplace.model.eval()
    assert callable(loss) or isinstance(loss, Metric)
    is_offline = not isinstance(loss, Metric)

    if is_offline:
        output_means, output_vars = list(), list()
        targets = list()

    for data in val_loader:
        if isinstance(data, MutableMapping):
            X, y = data, data[dict_key_y]
        else:
            X, y = data
            X = X.to(laplace._device)
        y = y.to(laplace._device)
        out = laplace(
            X,
            pred_type=pred_type,
            link_approx=link_approx,
            n_samples=n_samples,
            fitting=True,
        )

        if type(out) is tuple:
            if is_offline:
                output_means.append(out[0])
                output_vars.append(out[1])
                targets.append(y)
            else:
                try:
                    loss.update(*out, y)
                except TypeError:  # If the online loss only accepts 2 args
                    loss.update(out[0], y)
        else:
            if is_offline:
                output_means.append(out)
                targets.append(y)
            else:
                loss.update(out, y)

    if is_offline:
        if len(output_vars) == 0:
            preds, targets = torch.cat(output_means, dim=0), torch.cat(targets, dim=0)
            return loss(preds, targets).item()

        means, variances = torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0)
        targets = torch.cat(targets, dim=0)
        return loss(means, variances, targets).item()
    else:
        # Aggregate since torchmetrics output n_classes values for the MSE metric
        return loss.compute().sum().item()
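
Example (a hedged sketch): validating a fitted Laplace approximation with an online torchmetrics-style loss. It assumes the top-level laplace.Laplace factory and its fit method, which are documented outside this page; the toy model, data, and hyperparameters are placeholders.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace
from laplace.utils.metrics import RunningNLLMetric
from laplace.utils.utils import validate

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
X, y = torch.randn(64, 4), torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=16)
val_loader = DataLoader(TensorDataset(X, y), batch_size=16)

la = Laplace(model, "classification", subset_of_weights="last_layer", hessian_structure="kron")
la.fit(train_loader)

# average NLL over the validation set, using the probit link approximation
val_nll = validate(la, val_loader, RunningNLLMetric(), pred_type="glm", link_approx="probit")
print(val_nll)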

parameters_per_layer #

parameters_per_layer(model: Module) -> list[int]

Get number of parameters per layer.

Parameters:

  • model (Module) –

Returns:

  • params_per_layer ( list[int] ) –
Source code in laplace/utils/utils.py
def parameters_per_layer(model: nn.Module) -> list[int]:
    """Get number of parameters per layer.

    Parameters
    ----------
    model : torch.nn.Module

    Returns
    -------
    params_per_layer : list[int]
    """
    return [np.prod(p.shape) for p in model.parameters()]
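
Example (illustrative): note that the counts are per parameter tensor, so weights and biases of the same module appear as separate entries.

import torch.nn as nn
from laplace.utils.utils import parameters_per_layer

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print([int(n) for n in parameters_per_layer(model)])  # [32, 8, 16, 2]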

invsqrt_precision #

invsqrt_precision(M: Tensor) -> Tensor

Compute M^{-0.5} as a lower-triangular matrix.

Parameters:

  • M (Tensor) –

Returns:

  • M_invsqrt ( Tensor ) –
Source code in laplace/utils/utils.py
def invsqrt_precision(M: torch.Tensor) -> torch.Tensor:
    """Compute ``M^{-0.5}`` as a tridiagonal matrix.

    Parameters
    ----------
    M : torch.Tensor

    Returns
    -------
    M_invsqrt : torch.Tensor
    """
    return _precision_to_scale_tril(M)
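
Example (illustrative numerical check): for a symmetric positive-definite precision matrix M, the returned lower-triangular factor L satisfies L @ L.T ≈ M^{-1}.

import torch
from laplace.utils.utils import invsqrt_precision

torch.manual_seed(0)
A = torch.randn(5, 5)
M = A @ A.T + 5 * torch.eye(5)  # symmetric positive-definite

L = invsqrt_precision(M)
assert torch.allclose(L @ L.T, torch.linalg.inv(M), atol=1e-5)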

kron #

kron(t1: Tensor, t2: Tensor) -> Tensor

Computes the Kronecker product between two tensors.

Parameters:

  • t1 (Tensor) –
  • t2 (Tensor) –

Returns:

  • kron_product ( Tensor ) –
Source code in laplace/utils/utils.py
def kron(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Computes the Kronecker product between two tensors.

    Parameters
    ----------
    t1 : torch.Tensor
    t2 : torch.Tensor

    Returns
    -------
    kron_product : torch.Tensor
    """
    t1_height, t1_width = t1.size()
    t2_height, t2_width = t2.size()
    out_height = t1_height * t2_height
    out_width = t1_width * t2_width

    tiled_t2 = t2.repeat(t1_height, t1_width)
    expanded_t1 = (
        t1.unsqueeze(2)
        .unsqueeze(3)
        .repeat(1, t2_height, t2_width, 1)
        .view(out_height, out_width)
    )

    return expanded_t1 * tiled_t2
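
Example (illustrative): for 2-dimensional inputs the result agrees with torch.kron.

import torch
from laplace.utils.utils import kron

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[0.0, 1.0], [1.0, 0.0]])

K = kron(A, B)  # shape (4, 4)
assert torch.allclose(K, torch.kron(A, B))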

diagonal_add_scalar #

diagonal_add_scalar(X: Tensor, value: Tensor) -> Tensor

Add the scalar value to the diagonal of X.

Parameters:

  • X (Tensor) –
  • value (Tensor or float) –

Returns:

  • X_add_scalar ( Tensor ) –
Source code in laplace/utils/utils.py
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Add scalar value `value` to diagonal of `X`.

    Parameters
    ----------
    X : torch.Tensor
    value : torch.Tensor or float

    Returns
    -------
    X_add_scalar : torch.Tensor
    """
    indices = torch.tensor(
        [[i, i] for i in range(X.shape[0])], dtype=torch.long, device=X.device
    )
    values = X.new_ones(X.shape[0]).mul(value)
    return X.index_put(tuple(indices.t()), values, accumulate=True)
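
Example (illustrative): note that index_put is out-of-place, so the input X is left unchanged.

import torch
from laplace.utils.utils import diagonal_add_scalar

X = torch.zeros(3, 3)
Y = diagonal_add_scalar(X, 2.0)

print(Y.diagonal())  # tensor([2., 2., 2.])
print(X.diagonal())  # tensor([0., 0., 0.]) -- X is not modified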

symeig #

symeig(M: Tensor) -> tuple[Tensor, Tensor]

Symmetric eigendecomposition avoiding failure cases by adding and removing jitter to the diagonal.

Parameters:

  • M (Tensor) –

Returns:

  • L ( Tensor ) –

    eigenvalues

  • W ( Tensor ) –

    eigenvectors

Source code in laplace/utils/utils.py
def symeig(M: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symetric eigendecomposition avoiding failure cases by
    adding and removing jitter to the diagonal.

    Parameters
    ----------
    M : torch.Tensor

    Returns
    -------
    L : torch.Tensor
        eigenvalues
    W : torch.Tensor
        eigenvectors
    """
    try:
        L, W = torch.linalg.eigh(M, UPLO="U")
    except RuntimeError:  # did not converge
        logging.info("SYMEIG: adding jitter, did not converge.")
        # use W L W^T + I = W (L + I) W^T
        M = M + torch.eye(M.shape[0], device=M.device)
        try:
            L, W = torch.linalg.eigh(M, UPLO="U")
            L -= 1.0
        except RuntimeError:
            stats = f"diag: {M.diagonal()}, max: {M.abs().max()}, "
            stats = stats + f"min: {M.abs().min()}, mean: {M.abs().mean()}"
            logging.info(f"SYMEIG: adding jitter failed. Stats: {stats}")
            exit()
    # eigenvalues of symeig at least 0
    L = L.clamp(min=0.0)
    L = torch.nan_to_num(L)
    W = torch.nan_to_num(W)
    return L, W
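
Example (illustrative): for a symmetric positive semi-definite matrix, the (clamped) eigendecomposition reconstructs the input.

import torch
from laplace.utils.utils import symeig

torch.manual_seed(0)
A = torch.randn(4, 4)
M = A @ A.T  # symmetric positive semi-definite

L, W = symeig(M)
assert torch.allclose(W @ torch.diag(L) @ W.T, M, atol=1e-4)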

block_diag #

block_diag(blocks: list[Tensor]) -> Tensor

Compose block-diagonal matrix of individual blocks.

Parameters:

  • blocks (list[Tensor]) –

Returns:

  • M ( Tensor ) –
Source code in laplace/utils/utils.py
def block_diag(blocks: list[torch.Tensor]) -> torch.Tensor:
    """Compose block-diagonal matrix of individual blocks.

    Parameters
    ----------
    blocks : list[torch.Tensor]

    Returns
    -------
    M : torch.Tensor
    """
    P = sum([b.shape[0] for b in blocks])
    M = torch.zeros(P, P, dtype=blocks[0].dtype, device=blocks[0].device)
    p_cur = 0
    for block in blocks:
        p_block = block.shape[0]
        M[p_cur : p_cur + p_block, p_cur : p_cur + p_block] = block
        p_cur += p_block
    return M
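
Example (illustrative): for square blocks the result matches torch.block_diag.

import torch
from laplace.utils.utils import block_diag

A = torch.ones(2, 2)
B = 2 * torch.ones(3, 3)

M = block_diag([A, B])  # shape (5, 5)
assert torch.allclose(M, torch.block_diag(A, B))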

normal_samples #

normal_samples(mean: Tensor, var: Tensor, n_samples: int, generator: Generator | None = None) -> Tensor

Produce samples from a batch of Normal distributions parameterized by either a diagonal or a full covariance given by var.

Parameters:

  • mean (Tensor) –

    (batch_size, output_dim)

  • var (Tensor) –

    (co)variance of the Normal distribution (batch_size, output_dim, output_dim) or (batch_size, output_dim)

  • generator (Generator, default: None ) –

    random number generator

Source code in laplace/utils/utils.py
def normal_samples(
    mean: torch.Tensor,
    var: torch.Tensor,
    n_samples: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Produce samples from a batch of Normal distributions either parameterized
    by a diagonal or full covariance given by `var`.

    Parameters
    ----------
    mean : torch.Tensor
        `(batch_size, output_dim)`
    var : torch.Tensor
        (co)variance of the Normal distribution
        `(batch_size, output_dim, output_dim)` or `(batch_size, output_dim)`
    generator : torch.Generator
        random number generator
    """
    assert mean.ndim == 2, "Invalid input shape of mean, should be 2-dimensional."
    _, output_dim = mean.shape
    randn_samples = torch.randn(
        (output_dim, n_samples),
        device=mean.device,
        dtype=mean.dtype,
        generator=generator,
    )

    if mean.shape == var.shape:
        # diagonal covariance
        scaled_samples = var.sqrt().unsqueeze(-1) * randn_samples.unsqueeze(0)
        return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1))
    elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]:
        # full covariance
        scale = torch.linalg.cholesky(var)
        scaled_samples = torch.matmul(
            scale, randn_samples.unsqueeze(0)
        )  # expand batch dim
        return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1))
    else:
        raise ValueError("Invalid input shapes.")
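
Example (illustrative, shape-focused): samples are returned as (n_samples, batch_size, output_dim) for both parameterizations.

import torch
from laplace.utils.utils import normal_samples

mean = torch.zeros(4, 3)  # (batch_size, output_dim)

# diagonal covariance: var has the same shape as mean
var_diag = torch.ones(4, 3)
print(normal_samples(mean, var_diag, n_samples=10).shape)  # torch.Size([10, 4, 3])

# full covariance: var is (batch_size, output_dim, output_dim)
var_full = torch.eye(3).repeat(4, 1, 1)
print(normal_samples(mean, var_full, n_samples=10).shape)  # torch.Size([10, 4, 3])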

_is_batchnorm #

_is_batchnorm(module: Module) -> bool
Source code in laplace/utils/utils.py
def _is_batchnorm(module: nn.Module) -> bool:
    if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
        return True
    return False

_is_valid_scalar #

_is_valid_scalar(scalar: float | int | Tensor) -> bool
Source code in laplace/utils/utils.py
def _is_valid_scalar(scalar: float | int | torch.Tensor) -> bool:
    if np.isscalar(scalar) and np.isreal(scalar):
        return True
    elif torch.is_tensor(scalar) and scalar.ndim <= 1:
        if scalar.ndim == 1 and len(scalar) != 1:
            return False
        return True
    return False

expand_prior_precision #

expand_prior_precision(prior_prec: Tensor, model: Module) -> Tensor

Expand prior precision to match the shape of the model parameters.

Parameters:

  • prior_prec (torch.Tensor 1-dimensional) –

    prior precision

  • model (Module) –

    torch model with parameters that are regularized by prior_prec

Returns:

  • expanded_prior_prec ( Tensor ) –

    expanded prior precision has the same shape as model parameters

Source code in laplace/utils/utils.py
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """Expand prior precision to match the shape of the model parameters.

    Parameters
    ----------
    prior_prec : torch.Tensor 1-dimensional
        prior precision
    model : torch.nn.Module
        torch model with parameters that are regularized by prior_prec

    Returns
    -------
    expanded_prior_prec : torch.Tensor
        expanded prior precision has the same shape as model parameters
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    theta = parameters_to_vector(trainable_params)
    device, P = theta.device, len(theta)
    assert prior_prec.ndim == 1
    if len(prior_prec) == 1:  # scalar
        return torch.ones(P, device=device) * prior_prec
    elif len(prior_prec) == P:  # full diagonal
        return prior_prec.to(device)
    else:
        return torch.cat(
            [
                delta * torch.ones_like(m).flatten()
                for delta, m in zip(prior_prec, trainable_params)
            ]
        )
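
Example (illustrative) for the three supported inputs: a scalar, one value per parameter tensor, and a full diagonal.

import torch
import torch.nn as nn
from laplace.utils.utils import expand_prior_precision

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
n_params = sum(p.numel() for p in model.parameters())  # 58

print(expand_prior_precision(torch.tensor([1.0]), model).shape)   # scalar -> torch.Size([58])
print(expand_prior_precision(torch.ones(4), model).shape)         # per parameter tensor (4 here) -> torch.Size([58])
print(expand_prior_precision(torch.ones(n_params), model).shape)  # full diagonal -> torch.Size([58])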

fix_prior_prec_structure #

fix_prior_prec_structure(prior_prec_init: Tensor, prior_structure: PriorStructure | str, n_layers: int, n_params: int, device: device) -> Tensor

Create a tensor of prior precision with the correct shape, depending on the choice of the prior structure type.

Parameters:

  • prior_prec_init (Tensor) –

    the initial prior precision tensor (could be scalar)

  • prior_structure (PriorStructure | str) –

    the choice of the prior structure type

  • n_layers (int) –
  • n_params (int) –
  • device (device) –

Returns:

  • correct_prior_precision ( Tensor ) –
Source code in laplace/utils/utils.py
def fix_prior_prec_structure(
    prior_prec_init: torch.Tensor,
    prior_structure: PriorStructure | str,
    n_layers: int,
    n_params: int,
    device: torch.device,
) -> torch.Tensor:
    """Create a tensor of prior precision with the correct shape, depending on the
    choice of the prior structure type.

    Parameters
    ----------
    prior_prec_init: torch.Tensor
        the initial prior precision tensor (could be scalar)
    prior_structure: PriorStructure | str
        the choice of the prior structure type
    n_layers: int
    n_params: int
    device: torch.device

    Returns
    -------
    correct_prior_precision: torch.Tensor
    """
    if prior_structure == PriorStructure.SCALAR:
        prior_prec_init = torch.full((1,), prior_prec_init, device=device)
    elif prior_structure == PriorStructure.LAYERWISE:
        prior_prec_init = torch.full((n_layers,), prior_prec_init, device=device)
    elif prior_structure == PriorStructure.DIAG:
        prior_prec_init = torch.full((n_params,), prior_prec_init, device=device)
    else:
        raise ValueError(f"Invalid prior structure {prior_structure}.")
    return prior_prec_init
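
Example (a hedged sketch): it assumes, as the PriorStructure | str annotation suggests, that the plain strings 'scalar' and 'layerwise' are accepted aliases for the enum members, and that a plain float is a valid fill value for prior_prec_init.

import torch
from laplace.utils.utils import fix_prior_prec_structure

device = torch.device("cpu")

# one shared precision value
print(fix_prior_prec_structure(1.0, "scalar", n_layers=4, n_params=58, device=device))
# tensor([1.])

# one precision value per layer
print(fix_prior_prec_structure(1.0, "layerwise", n_layers=4, n_params=58, device=device))
# tensor([1., 1., 1., 1.])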

fit_diagonal_swag_var #

fit_diagonal_swag_var(model: Module, train_loader: DataLoader, criterion: CrossEntropyLoss | MSELoss, n_snapshots_total: int = 40, snapshot_freq: int = 1, lr: float = 0.01, momentum: float = 0.9, weight_decay: float = 0.0003, min_var: float = 1e-30) -> Tensor

Fit diagonal SWAG [1], which estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate.

Implementation partly adapted from:

  • https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
  • https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

References

[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.

Parameters:

  • model (Module) –
  • train_loader (DataLoader) –

    training data loader to use for snapshot collection

  • criterion (CrossEntropyLoss or MSELoss) –

    loss function to use for snapshot collection

  • n_snapshots_total (int, default: 40 ) –

    total number of model snapshots to collect

  • snapshot_freq (int, default: 1 ) –

    snapshot collection frequency (in epochs)

  • lr (float, default: 0.01 ) –

    SGD learning rate for collecting snapshots

  • momentum (float, default: 0.9 ) –

    SGD momentum

  • weight_decay (float, default: 0.0003 ) –

    SGD weight decay

  • min_var (float, default: 1e-30 ) –

    minimum parameter variance to clamp to (for numerical stability)

Returns:

  • param_variances ( Tensor ) –

    vector of marginal variances for each model parameter

Source code in laplace/utils/swag.py
def fit_diagonal_swag_var(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.CrossEntropyLoss | nn.MSELoss,
    n_snapshots_total: int = 40,
    snapshot_freq: int = 1,
    lr: float = 0.01,
    momentum: float = 0.9,
    weight_decay: float = 3e-4,
    min_var: float = 1e-30,
) -> torch.Tensor:
    """
    Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
    computing the first and second moment of SGD iterates with a large learning rate.

    Implementation partly adapted from:
    - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
    - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

    References
    ----------
    [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG.
    [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476).
    NeurIPS 2019.

    Parameters
    ----------
    model : torch.nn.Module
    train_loader : torch.utils.data.DataLoader
        training data loader to use for snapshot collection
    criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
        loss function to use for snapshot collection
    n_snapshots_total : int
        total number of model snapshots to collect
    snapshot_freq : int
        snapshot collection frequency (in epochs)
    lr : float
        SGD learning rate for collecting snapshots
    momentum : float
        SGD momentum
    weight_decay : float
        SGD weight decay
    min_var : float
        minimum parameter variance to clamp to (for numerical stability)

    Returns
    -------
    param_variances : torch.Tensor
        vector of marginal variances for each model parameter
    """

    # create a copy of the model to avoid undesired changes to the original model parameters
    _model: nn.Module = deepcopy(model)
    _model.train()
    device: torch.device = next(_model.parameters()).device

    # initialize running estimates of first and second moment of model parameters
    mean: torch.Tensor = torch.zeros_like(_param_vector(_model))
    sq_mean: torch.Tensor = torch.zeros_like(_param_vector(_model))
    n_snapshots: int = 0

    # run SGD to collect model snapshots
    optimizer: Optimizer = torch.optim.SGD(
        _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    n_epochs: int = snapshot_freq * n_snapshots_total

    for epoch in range(n_epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(_model(inputs), targets)
            loss.backward()
            optimizer.step()

        if epoch % snapshot_freq == 0:
            # update running estimates of first and second moment of model parameters
            old_fac, new_fac = n_snapshots / (n_snapshots + 1), 1 / (n_snapshots + 1)
            mean = mean * old_fac + _param_vector(_model) * new_fac
            sq_mean = sq_mean * old_fac + _param_vector(_model) ** 2 * new_fac
            n_snapshots += 1

    # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
    param_variances: torch.Tensor = torch.clamp(sq_mean - mean**2, min_var)
    return param_variances
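
Example (illustrative): a tiny regression setup; the snapshot count is kept small only so the sketch runs quickly.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace.utils.swag import fit_diagonal_swag_var

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
X, y = torch.randn(64, 4), torch.randn(64, 1)
train_loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

param_vars = fit_diagonal_swag_var(
    model, train_loader, criterion=nn.MSELoss(), n_snapshots_total=5, snapshot_freq=1, lr=0.01
)

print(param_vars.shape)  # torch.Size([97]): one marginal variance per parameter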