Module laplace.utils.subnetmask

Classes

class SubnetMask (model: nn.Module)

Base class for all subnetwork masks in this library (for subnetwork Laplace).

Parameters

model : torch.nn.Module
 

Instance variables

var indices : torch.LongTensor
var n_params_subnet : int

Methods

def convert_subnet_mask_to_indices(self, subnet_mask: torch.Tensor) ‑> torch.LongTensor

Converts a subnetwork mask into subnetwork indices.

Parameters

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork
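
As a rough illustration of the conversion described above, the following sketch turns a binary mask over the vectorized parameters into index form using plain PyTorch. The toy model and the chosen mask positions are assumptions made for the example; this is not the library's own code.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

# Toy model (assumption): 3*2 weights + 2 biases = 8 parameters
model = nn.Linear(3, 2)
n_params = parameters_to_vector(model.parameters()).numel()

# Binary mask of size (n_params); True marks subnetwork parameters
subnet_mask = torch.zeros(n_params, dtype=torch.bool)
subnet_mask[[0, 5, 7]] = True

# Positions of the nonzero entries, returned as a 1-D LongTensor
subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
```

The resulting indices refer to positions in the same flattened ordering that torch.nn.utils.parameters_to_vector produces.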

def select(self, train_loader: DataLoader | None = None) ‑> torch.LongTensor

Select the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader, default=None
each iteration yields a training batch (X, y); train_loader.dataset must be set so that N, the size of the data set, can be accessed

Returns

subnet_mask_indices : torch.LongTensor
a vector of indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask.

Parameters

train_loader : torch.utils.data.DataLoader
each iteration yields a training batch (X, y); train_loader.dataset must be set so that N, the size of the data set, can be accessed

Returns

subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters within the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters()))

class RandomSubnetMask (model: nn.Module, n_params_subnet: int)

Subnetwork mask of parameters sampled uniformly at random.

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
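
One way to draw parameters uniformly at random without replacement is to give every parameter an independent random score and keep the highest-scoring ones. The sketch below shows that idea in plain PyTorch; the toy model and sizes are assumptions, and the library's actual scoring code may differ.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Toy model (assumption): 10*3 weights + 3 biases = 33 parameters
model = nn.Linear(10, 3)
n_params = parameters_to_vector(model.parameters()).numel()
n_params_subnet = 6

# One i.i.d. uniform score per parameter; the top scores form a uniform random subset
param_scores = torch.rand(n_params)
subnet_indices = torch.topk(param_scores, n_params_subnet).indices.sort().values
```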

class LargestMagnitudeSubnetMask (model: nn.Module, n_params_subnet: int)

Subnetwork mask identifying the parameters with the largest magnitude.

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
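
A minimal sketch of magnitude-based selection, assuming the score of each parameter is simply its absolute value in the vectorized parameter tensor. The toy model and the value of n_params_subnet are assumptions for illustration.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Toy model (assumption): 32 + 8 + 16 + 2 = 58 parameters
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
n_params_subnet = 5

# Score each parameter by its magnitude
param_scores = parameters_to_vector(model.parameters()).detach().abs()

# Keep the indices of the n_params_subnet largest scores, in ascending order
subnet_indices = torch.topk(param_scores, n_params_subnet).indices.sort().values
```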

class LargestVarianceDiagLaplaceSubnetMask (model: nn.Module, n_params_subnet: int, diag_laplace_model: DiagLaplace)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using a diagonal Laplace approximation over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
diag_laplace_model : DiagLaplace
diagonal Laplace model to use for variance estimation

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
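
Under a diagonal Gaussian posterior, the marginal variance of each parameter is the reciprocal of its posterior precision. The sketch below uses a hypothetical, hand-written precision vector in place of a fitted diagonal Laplace model and selects the highest-variance parameters from it.

```python
import torch

# Hypothetical diagonal posterior precision, one entry per parameter (assumption)
posterior_precision = torch.tensor([4.0, 0.5, 10.0, 0.25, 2.0, 1.0])
n_params_subnet = 2

# Diagonal Gaussian: marginal variance is the inverse of the precision
marginal_variances = 1.0 / posterior_precision

# Indices of the parameters with the largest marginal variances
subnet_indices = torch.topk(marginal_variances, n_params_subnet).indices.sort().values
```

Here the two lowest-precision entries (positions 1 and 3) are the ones selected.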

class LargestVarianceSWAGSubnetMask (model: nn.Module, n_params_subnet: int, likelihood: Likelihood | str = Likelihood.CLASSIFICATION, swag_n_snapshots: int = 40, swag_snapshot_freq: int = 1, swag_lr: float = 0.01)

Subnetwork mask identifying the parameters with the largest marginal variances (estimated using diagonal SWAG over all model parameters).

Parameters

model : torch.nn.Module
 
n_params_subnet : int
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
likelihood : Likelihood or str, default=Likelihood.CLASSIFICATION
'classification' or 'regression'
swag_n_snapshots : int, default=40
number of model snapshots to collect for SWAG
swag_snapshot_freq : int, default=1
SWAG snapshot collection frequency (in epochs)
swag_lr : float, default=0.01
learning rate for SWAG snapshot collection

Ancestors

  • laplace.utils.subnetmask.ScoreBasedSubnetMask
  • SubnetMask

Methods

def compute_param_scores(self, train_loader: DataLoader) ‑> torch.Tensor
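
Diagonal SWAG estimates per-parameter variances from the first and second moments of parameter snapshots collected along an SGD trajectory. The following is a minimal sketch of that estimate on synthetic data; the data, the model, and the use of one full-batch step per snapshot are simplifying assumptions, not the library's training loop.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Synthetic classification data and toy model (assumptions)
X, y = torch.randn(64, 4), torch.randint(0, 3, (64,))
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # plays the role of swag_lr
loss_fn = nn.CrossEntropyLoss()

n_snapshots = 10  # plays the role of swag_n_snapshots
n_params = parameters_to_vector(model.parameters()).numel()
mean = torch.zeros(n_params)
sq_mean = torch.zeros(n_params)

for k in range(n_snapshots):
    # One full-batch SGD step stands in for an epoch of training
    optimizer.zero_grad()
    loss_fn(model(X), y).backward()
    optimizer.step()

    # Running averages of the parameters and of their squares
    theta = parameters_to_vector(model.parameters()).detach()
    mean = (k * mean + theta) / (k + 1)
    sq_mean = (k * sq_mean + theta**2) / (k + 1)

# Diagonal variance estimate, clamped to stay strictly positive
variances = (sq_mean - mean**2).clamp(min=1e-30)
subnet_indices = torch.topk(variances, 5).indices.sort().values
```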

class ParamNameSubnetMask (model: nn.Module, parameter_names: list[str])

Subnetwork mask corresponding to the specified parameters of the neural network.

Parameters

model : torch.nn.Module
 
parameter_names : list[str]
list of names of the parameters (as in model.named_parameters()) that define the subnetwork

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the specified parameters.
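
To make the role of parameter_names concrete, this sketch builds a binary mask by walking model.named_parameters() and marking every entry of the named tensors. The toy model and the chosen names are assumptions, and the code is an illustration rather than the library's implementation.

```python
import torch
from torch import nn

# Toy model (assumption); parameter names are "0.weight", "0.bias", "2.weight", "2.bias"
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
parameter_names = ["0.bias", "2.weight"]

mask_parts = []
for name, param in model.named_parameters():
    # Ones for selected parameter tensors, zeros for all others
    fill = torch.ones_like if name in parameter_names else torch.zeros_like
    mask_parts.append(fill(param).flatten())

# Concatenate in the same order used by parameters_to_vector
subnet_mask = torch.cat(mask_parts).bool()
subnet_indices = subnet_mask.nonzero(as_tuple=True)[0]
```

With this model, "0.weight" occupies positions 0 to 11 of the flattened vector, so the selected tensors cover positions 12 to 23.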

class ModuleNameSubnetMask (model: nn.Module, module_names: list[str])

Subnetwork mask corresponding to the specified modules of the neural network.

Parameters

model : torch.nn.Module
 
module_names : list[str]
list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the specified modules.
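
The sketch below illustrates module-based selection, including the leaf-module requirement: every name must refer to a module without children, and all parameters of the named modules are marked. It assumes a model whose parameters all belong to leaf modules; the toy model and names are illustrative, not taken from the library.

```python
import torch
from torch import nn

# Toy model (assumption); module names are "0", "1", "2"
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
module_names = ["2"]

# Check that every requested module exists and is a leaf module
modules = dict(model.named_modules())
for name in module_names:
    if name not in modules:
        raise ValueError(f"Module {name!r} not found in the model.")
    if len(list(modules[name].children())) > 0:
        raise ValueError(f"Module {name!r} has children; only leaf modules are allowed.")

mask_parts = []
for name, module in model.named_modules():
    if len(list(module.children())) > 0:
        continue  # skip containers; their parameters are reached via the leaves
    selected = name in module_names
    for param in module.parameters():
        mask_parts.append(torch.full((param.numel(),), selected, dtype=torch.bool))

subnet_mask = torch.cat(mask_parts)
subnet_indices = subnet_mask.nonzero(as_tuple=True)[0]
```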

class LastLayerSubnetMask (model: nn.Module, last_layer_name: str | None = None)

Subnetwork mask corresponding to the last layer of the neural network.

Parameters

model : torch.nn.Module
 
last_layer_name : str, default=None
name of the model's last layer; if None, it is determined automatically

Methods

def get_subnet_mask(self, train_loader: DataLoader) ‑> torch.Tensor

Get the subnetwork mask identifying the last layer.
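
For a simple feed-forward model, the last layer's parameters sit at the tail of the vectorized parameters. The sketch below picks the last layer with a naive heuristic (the last registered leaf module that owns parameters); this heuristic is an assumption for illustration and is not necessarily how the library determines the last layer when last_layer_name is None.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector

# Toy sequential model (assumption)
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))

# Heuristic (assumption): last registered leaf module that owns parameters
last_layer_name, last_layer = [
    (name, module)
    for name, module in model.named_modules()
    if len(list(module.children())) == 0 and len(list(module.parameters())) > 0
][-1]

n_params = parameters_to_vector(model.parameters()).numel()
n_last = sum(p.numel() for p in last_layer.parameters())

# For this sequential model, the last layer occupies the final n_last positions
subnet_indices = torch.arange(n_params - n_last, n_params)
```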
