Module laplace.utils.swag
Functions
def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=0.0003, min_var=1e-30)
Fit diagonal SWAG [1], which runs SGD with a large learning rate and estimates the marginal variance of each model parameter from the first and second moments of the SGD iterates, i.e. Var[θ] = E[θ²] - E[θ]².
Implementation partly adapted from:
- https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
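The moment computation in the description can be sketched as follows. This is a minimal sketch, not the code of `laplace.utils.swag`: it assumes the parameters are flattened into one vector and that every snapshot gets equal weight. `update_moments` and `diagonal_variance` are hypothetical helpers.

```python
import torch


def update_moments(mean, sq_mean, n_seen, theta):
    """Fold one flattened parameter snapshot `theta` into running averages.

    Hypothetical helper: `mean` and `sq_mean` hold the averages of the first
    `n_seen` snapshots and of their element-wise squares.
    """
    old_w, new_w = n_seen / (n_seen + 1), 1.0 / (n_seen + 1)
    mean = mean * old_w + theta * new_w
    sq_mean = sq_mean * old_w + theta.pow(2) * new_w
    return mean, sq_mean, n_seen + 1


def diagonal_variance(mean, sq_mean, min_var=1e-30):
    """Var[theta] = E[theta^2] - E[theta]^2, clamped from below at `min_var`."""
    return torch.clamp(sq_mean - mean.pow(2), min=min_var)


# Toy example: three snapshots of a two-parameter "model".
snapshots = [torch.tensor([1.0, 5.0]), torch.tensor([2.0, 5.0]), torch.tensor([3.0, 5.0])]
mean, sq_mean, n = torch.zeros(2), torch.zeros(2), 0
for theta in snapshots:
    mean, sq_mean, n = update_moments(mean, sq_mean, n, theta)
var = diagonal_variance(mean, sq_mean)  # first entry is about 2/3, second is near zero
```

Updating the averages incrementally avoids storing every snapshot, so memory stays at two parameter-sized vectors regardless of how many snapshots are collected.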
References
[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.
Parameters
model
:torch.nn.Module
- model whose marginal parameter variances are estimated
train_loader
:torch.utils.data.DataLoader
- training data loader to use for snapshot collection
criterion
:torch.nn.CrossEntropyLoss or torch.nn.MSELoss
- loss function to use for snapshot collection
n_snapshots_total
:int
- total number of model snapshots to collect
snapshot_freq
:int
- snapshot collection frequency (in epochs)
lr
:float
- SGD learning rate for collecting snapshots
momentum
:float
- SGD momentum
weight_decay
:float
- SGD weight decay
min_var
:float
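- lower bound to which the estimated variances are clamped (for numerical stability, since E[θ²] - E[θ]² can come out slightly negative due to floating-point round-off)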
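How the SGD hyperparameters and the snapshot schedule might fit together can be sketched as below. This is a minimal sketch, not the library's implementation: it assumes training runs for `n_snapshots_total * snapshot_freq` epochs with one snapshot at the end of every `snapshot_freq`-th epoch, that parameters are flattened with `torch.nn.utils.parameters_to_vector`, and that model and data already live on the same device. The function name `swag_diag_var_sketch` is hypothetical.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader, TensorDataset


def swag_diag_var_sketch(model, train_loader, criterion, n_snapshots_total=40,
                         snapshot_freq=1, lr=0.01, momentum=0.9,
                         weight_decay=3e-4, min_var=1e-30):
    # Train a copy so the caller's model is left untouched (a choice of this sketch).
    model = copy.deepcopy(model)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=momentum, weight_decay=weight_decay)
    with torch.no_grad():
        theta = parameters_to_vector(model.parameters())
    mean, sq_mean, n = torch.zeros_like(theta), torch.zeros_like(theta), 0
    for epoch in range(n_snapshots_total * snapshot_freq):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # Assumed schedule: snapshot after every `snapshot_freq`-th epoch.
        if (epoch + 1) % snapshot_freq == 0:
            with torch.no_grad():
                theta = parameters_to_vector(model.parameters())
            mean = mean * (n / (n + 1)) + theta / (n + 1)
            sq_mean = sq_mean * (n / (n + 1)) + theta.pow(2) / (n + 1)
            n += 1
    return torch.clamp(sq_mean - mean.pow(2), min=min_var)


# Usage on a toy 3-class classification problem.
torch.manual_seed(0)
X, y = torch.randn(64, 4), torch.randint(0, 3, (64,))
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
variances = swag_diag_var_sketch(net, loader, nn.CrossEntropyLoss(), n_snapshots_total=5)
```

Under this schedule, raising `snapshot_freq` spaces snapshots further apart in training time and lengthens training proportionally, while `n_snapshots_total` fixes how many iterates enter the moment estimates.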
Returns
param_variances
:torch.Tensor
- vector of marginal variances for each model parameter
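Because the return value is a single flat vector, it can help to view it per parameter tensor. The sketch below assumes the entries follow the order of `model.parameters()`, as `torch.nn.utils.parameters_to_vector` would produce; that ordering is an assumption, not something this page documents. `split_by_parameter` is a hypothetical helper.

```python
import torch
from torch import nn


def split_by_parameter(model, flat):
    """Split a flat per-parameter vector into tensors shaped like model.parameters()."""
    params = list(model.parameters())
    sizes = [p.numel() for p in params]
    if flat.numel() != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} entries, got {flat.numel()}")
    chunks = torch.split(flat, sizes)
    # named_parameters() yields parameters in the same order as parameters().
    return {name: chunk.view_as(p)
            for (name, p), chunk in zip(model.named_parameters(), chunks)}


layer = nn.Linear(3, 2)
param_variances = torch.full((8,), 1e-3)  # stand-in for the function's output
per_param = split_by_parameter(layer, param_variances)
# per_param["weight"].shape == (2, 3), per_param["bias"].shape == (2,)
```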