Module laplace.laplace
Functions
def Laplace(model: torch.nn.Module, likelihood: Likelihood | str, subset_of_weights: SubsetOfWeights | str = SubsetOfWeights.LAST_LAYER, hessian_structure: HessianStructure | str = HessianStructure.KRON, *args, **kwargs)
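Simplified access to the Laplace approximations: the variant is selected with strings (or the corresponding enum members) instead of by instantiating a specific subclass directly.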
Parameters
model : torch.nn.Module
- the neural network to which the Laplace approximation is applied
likelihood : Likelihood or str in {'classification', 'regression'}
subset_of_weights : SubsetOfWeights or str in {'last_layer', 'subnetwork', 'all'}, default=SubsetOfWeights.LAST_LAYER
- subset of weights to consider for inference
hessian_structure : HessianStructure or str in {'diag', 'kron', 'full', 'lowrank', 'gp'}, default=HessianStructure.KRON
- structure of the Hessian approximation (with 'gp', no Hessian approximation is actually performed; inference is instead done in function space)
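A type such as `HessianStructure | str` means a caller may pass either an enum member or its plain string value. A minimal sketch of how both forms can be accepted, assuming the enum mixes in `str`. The enum below only mirrors the documented string values and the helper `normalize_structure` is hypothetical; neither is the library's own definition.

```python
from enum import Enum


# Illustrative str-backed enum mirroring the documented hessian_structure values.
# Because it also subclasses str, each member compares equal to its string value.
class HessianStructure(str, Enum):
    DIAG = "diag"
    KRON = "kron"
    FULL = "full"
    LOWRANK = "lowrank"
    GP = "gp"


def normalize_structure(value: "HessianStructure | str") -> HessianStructure:
    """Hypothetical helper: accept an enum member or its string value."""
    # Calling the enum looks a member up by value: HessianStructure("kron") is KRON.
    # Passing a member returns that same member; an unknown string raises ValueError.
    return HessianStructure(value)


print(normalize_structure("kron") is HessianStructure.KRON)  # True
print(HessianStructure.DIAG == "diag")  # True
try:
    normalize_structure("dense")
except ValueError as err:
    print("rejected:", err)
```

Mixing in `str` is what lets the string and enum spellings be used interchangeably without any extra conversion code at the call site.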
Returns
laplace : BaseLaplace
- the BaseLaplace subclass selected by subset_of_weights and hessian_structure, instantiated with model, likelihood, and any additional *args and **kwargs
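Choosing a subclass from the two string options can be pictured as a lookup keyed on the `(subset_of_weights, hessian_structure)` pair. The sketch below is one plausible way to do this, not the library's implementation. The classes `AllWeightsKronLaplace`, `LastLayerKronLaplace` and `AllWeightsDiagLaplace`, the stand-in `BaseLaplace`, and the function `laplace_factory` are all hypothetical placeholders.

```python
from enum import Enum
from typing import Any


class SubsetOfWeights(str, Enum):
    LAST_LAYER = "last_layer"
    SUBNETWORK = "subnetwork"
    ALL = "all"


class HessianStructure(str, Enum):
    DIAG = "diag"
    KRON = "kron"
    FULL = "full"
    LOWRANK = "lowrank"
    GP = "gp"


class BaseLaplace:
    """Hypothetical stand-in base class; stores whatever it was constructed with."""

    def __init__(self, model: Any, likelihood: str, *args: Any, **kwargs: Any) -> None:
        self.model = model
        self.likelihood = likelihood
        self.args = args
        self.kwargs = kwargs


# Hypothetical subclasses, each registered under a (subset_of_weights, hessian_structure) key.
class AllWeightsKronLaplace(BaseLaplace):
    key = (SubsetOfWeights.ALL, HessianStructure.KRON)


class LastLayerKronLaplace(BaseLaplace):
    key = (SubsetOfWeights.LAST_LAYER, HessianStructure.KRON)


class AllWeightsDiagLaplace(BaseLaplace):
    key = (SubsetOfWeights.ALL, HessianStructure.DIAG)


_REGISTRY = {
    cls.key: cls
    for cls in (AllWeightsKronLaplace, LastLayerKronLaplace, AllWeightsDiagLaplace)
}


def laplace_factory(model, likelihood, subset_of_weights="last_layer",
                    hessian_structure="kron", *args, **kwargs) -> BaseLaplace:
    # Normalize strings to enum members so both spellings hit the same registry key.
    key = (SubsetOfWeights(subset_of_weights), HessianStructure(hessian_structure))
    if key not in _REGISTRY:
        raise ValueError(f"no Laplace variant registered for {key}")
    # Forward model, likelihood and any extra arguments to the selected subclass.
    return _REGISTRY[key](model, likelihood, *args, **kwargs)


la = laplace_factory("my-model", "regression", "all", HessianStructure.DIAG, prior_precision=1.0)
print(type(la).__name__)  # AllWeightsDiagLaplace
print(la.kwargs)  # {'prior_precision': 1.0}
```

With the documented function, the corresponding call would be `Laplace(model, "regression", subset_of_weights="all", hessian_structure="diag")`. A registry keyed on the option pair keeps the string-based entry point thin: supporting a new variant only requires registering another subclass under its key.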