Module laplace.utils.swag

Functions

def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=0.0003, min_var=1e-30)

Fit diagonal SWAG [1], which estimates the marginal variances of the model parameters from the first and second moments of SGD iterates collected with a large learning rate.
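The moment estimates described above can be sketched as running averages over snapshots. This is a minimal sketch in plain Python, using lists in place of torch tensors, and it assumes each snapshot is a flat list of parameter values. `update_moments` and `diagonal_variance` are hypothetical helpers, not part of `laplace.utils.swag`. After n snapshots, the variance of each parameter is approximated by E[theta^2] - E[theta]^2, clamped below at `min_var`.

```python
from typing import List, Sequence


def update_moments(mean: List[float], sq_mean: List[float],
                   params: Sequence[float], n_seen: int) -> None:
    """Fold one snapshot into the running first and second moments, in place."""
    # Weights for a running average over n_seen + 1 snapshots.
    old_fac = n_seen / (n_seen + 1)
    new_fac = 1.0 / (n_seen + 1)
    for i, p in enumerate(params):
        mean[i] = old_fac * mean[i] + new_fac * p
        sq_mean[i] = old_fac * sq_mean[i] + new_fac * p * p


def diagonal_variance(mean: Sequence[float], sq_mean: Sequence[float],
                      min_var: float = 1e-30) -> List[float]:
    """Var[theta_i] ~= E[theta_i^2] - E[theta_i]^2, clamped below at min_var."""
    return [max(s - m * m, min_var) for m, s in zip(mean, sq_mean)]


# Hypothetical toy data: three snapshots of a 2-parameter model
# (a stand-in for flattened parameter vectors taken during SGD).
snapshots = [[1.0, 0.5], [2.0, 0.5], [3.0, 0.5]]
mean = [0.0, 0.0]
sq_mean = [0.0, 0.0]
for n, snap in enumerate(snapshots):
    update_moments(mean, sq_mean, snap, n)

variances = diagonal_variance(mean, sq_mean, min_var=1e-30)
```

With these toy snapshots, the first parameter gets a variance of about 2/3 (the population variance of 1, 2, 3), while the constant second parameter ends up near zero; the clamp at `min_var` keeps such entries from turning negative through rounding.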

Implementation partly adapted from:

- https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

References

[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, A. G. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.

Parameters

model : torch.nn.Module
model whose marginal parameter variances are estimated
train_loader : torch.utils.data.DataLoader
training data loader to use for snapshot collection
criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
loss function to use for snapshot collection
n_snapshots_total : int
total number of model snapshots to collect
snapshot_freq : int
snapshot collection frequency (in epochs)
lr : float
SGD learning rate for collecting snapshots
momentum : float
SGD momentum
weight_decay : float
SGD weight decay
min_var : float
minimum parameter variance to clamp to (for numerical stability)
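The interplay of `n_snapshots_total` and `snapshot_freq` can be illustrated with a small schedule helper. This sketch assumes that SGD runs for `n_snapshots_total * snapshot_freq` epochs and that one snapshot is taken at the end of every `snapshot_freq`-th epoch; `snapshot_epochs` is a hypothetical helper, and the exact epoch offsets used by the library may differ.

```python
from typing import List


def snapshot_epochs(n_snapshots_total: int = 40, snapshot_freq: int = 1) -> List[int]:
    """Return the 0-based epochs after which a snapshot would be collected.

    Assumption: n_snapshots_total * snapshot_freq epochs in total, with a
    snapshot at the end of every snapshot_freq-th epoch.
    """
    if n_snapshots_total < 1 or snapshot_freq < 1:
        raise ValueError("n_snapshots_total and snapshot_freq must be positive")
    n_epochs = n_snapshots_total * snapshot_freq
    # Keep the epochs that close a block of snapshot_freq epochs.
    return [epoch for epoch in range(n_epochs) if (epoch + 1) % snapshot_freq == 0]


print(snapshot_epochs(4, 3))  # [2, 5, 8, 11]
```

Under this assumption, `snapshot_epochs(4, 3)` describes 12 epochs of SGD with a snapshot after every third one, and the defaults (`n_snapshots_total=40`, `snapshot_freq=1`) give one snapshot per epoch for 40 epochs.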

Returns

param_variances : torch.Tensor
vector of marginal variances for each model parameter