Module laplace.laplace

Functions

def Laplace(model: torch.nn.Module, likelihood: Likelihood | str, subset_of_weights: SubsetOfWeights | str = SubsetOfWeights.LAST_LAYER, hessian_structure: HessianStructure | str = HessianStructure.KRON, *args, **kwargs) -> ParametricLaplace

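Simplified access to the Laplace approximations: the appropriate class is selected from strings (or enum members) instead of instantiating the different classes directly.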

Parameters

model : torch.nn.Module
neural network to which the Laplace approximation is applied
likelihood : Likelihood or str in {'classification', 'regression'}
likelihood type of the task
subset_of_weights : SubsetOfWeights or str in {'last_layer', 'subnetwork', 'all'}, default=SubsetOfWeights.LAST_LAYER
subset of weights to consider for inference
hessian_structure : HessianStructure or str in {'diag', 'kron', 'full', 'lowrank'}, default=HessianStructure.KRON
structure of the Hessian approximation
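As a rough illustration of how the hessian_structure options trade memory for fidelity, the sketch below counts stored entries for a single linear layer whose weight matrix has shape (n_out, n_in). It assumes that 'diag' keeps one value per weight, 'full' keeps the dense matrix over all weights, and 'kron' keeps two KFAC-style Kronecker factors of sizes (n_in × n_in) and (n_out × n_out). The helper hessian_entries is hypothetical; biases and the 'lowrank' option are ignored, and this is not the library's own memory accounting.

```python
def hessian_entries(n_out: int, n_in: int) -> dict:
    """Hypothetical helper: number of stored Hessian entries per structure
    for one linear layer (bias ignored)."""
    n = n_out * n_in  # number of weights in the layer
    return {
        "diag": n,                            # one curvature value per weight
        "kron": n_in * n_in + n_out * n_out,  # two Kronecker factors
        "full": n * n,                        # dense n x n matrix
    }


print(hessian_entries(10, 100))  # {'diag': 1000, 'kron': 10100, 'full': 1000000}
```

Even for this small layer, the full structure needs a million entries, whereas the Kronecker factors need about ten thousand.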

Returns

laplace : ParametricLaplace
the subclass of ParametricLaplace selected by subset_of_weights and hessian_structure, instantiated with model, likelihood, and any additional *args and **kwargs
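The string-based dispatch described above can be sketched as a lookup from (subset_of_weights, hessian_structure) pairs to classes. This is a minimal, self-contained sketch, not the library's implementation: the enums are re-declared locally, StubLaplace and laplace_factory are hypothetical stand-ins for ParametricLaplace and Laplace, and the set of supported combinations in the registry is an assumption for illustration.

```python
from enum import Enum


class SubsetOfWeights(str, Enum):
    ALL = "all"
    LAST_LAYER = "last_layer"
    SUBNETWORK = "subnetwork"


class HessianStructure(str, Enum):
    FULL = "full"
    KRON = "kron"
    DIAG = "diag"
    LOWRANK = "lowrank"


class StubLaplace:
    """Hypothetical stand-in for a ParametricLaplace subclass."""

    def __init__(self, model, likelihood, *args, **kwargs):
        self.model = model
        self.likelihood = likelihood
        self.args = args
        self.kwargs = kwargs


# Hypothetical registry: one stub subclass per (assumed) supported combination.
_REGISTRY = {
    (subset, structure): type(f"{subset.name}_{structure.name}", (StubLaplace,), {})
    for subset, structure in [
        (SubsetOfWeights.ALL, HessianStructure.FULL),
        (SubsetOfWeights.ALL, HessianStructure.KRON),
        (SubsetOfWeights.ALL, HessianStructure.DIAG),
        (SubsetOfWeights.ALL, HessianStructure.LOWRANK),
        (SubsetOfWeights.LAST_LAYER, HessianStructure.FULL),
        (SubsetOfWeights.LAST_LAYER, HessianStructure.KRON),
        (SubsetOfWeights.LAST_LAYER, HessianStructure.DIAG),
        (SubsetOfWeights.SUBNETWORK, HessianStructure.FULL),
        (SubsetOfWeights.SUBNETWORK, HessianStructure.DIAG),
    ]
}


def laplace_factory(model, likelihood,
                    subset_of_weights=SubsetOfWeights.LAST_LAYER,
                    hessian_structure=HessianStructure.KRON,
                    *args, **kwargs):
    # Calling the Enum accepts either a member or its string value;
    # unknown strings raise ValueError.
    subset = SubsetOfWeights(subset_of_weights)
    structure = HessianStructure(hessian_structure)
    try:
        cls = _REGISTRY[(subset, structure)]
    except KeyError:
        raise ValueError(
            f"Unsupported combination: {subset.value!r}, {structure.value!r}"
        ) from None
    # Forward everything else unchanged to the chosen class.
    return cls(model, likelihood, *args, **kwargs)


la = laplace_factory("my_model", "classification", "last_layer", "diag", prior_precision=1.0)
print(type(la).__name__)  # LAST_LAYER_DIAG
```

Normalizing both inputs through the enums means strings and enum members behave identically, and unsupported pairs fail early with a clear error. With the real library, the corresponding call is Laplace(model, 'classification', subset_of_weights='last_layer', hessian_structure='diag').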