Module laplace.laplace

Functions

def Laplace(model: torch.nn.Module, likelihood: Likelihood | str, subset_of_weights: SubsetOfWeights | str = SubsetOfWeights.LAST_LAYER, hessian_structure: HessianStructure | str = HessianStructure.KRON, *args, **kwargs)

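Simplified access to the Laplace approximations: the appropriate BaseLaplace subclass is selected from strings (or enum members) instead of requiring the user to instantiate one of the individual classes directly.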

Parameters

model : torch.nn.Module
 
likelihood : Likelihood or str in {'classification', 'regression'}
 
subset_of_weights : SubsetOfWeights or str in {'last_layer', 'subnetwork', 'all'}, default=SubsetOfWeights.LAST_LAYER
subset of weights to consider for inference
hessian_structure : HessianStructure or str in {'diag', 'kron', 'full', 'lowrank', 'gp'}, default=HessianStructure.KRON
structure of the Hessian approximation (with 'gp', no Hessian approximation is performed; inference is instead carried out in function space)
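As an illustration of how the parameters above can accept either an enum member or its string value, here is a minimal sketch using Python's standard `enum` module. The enums `SubsetOfWeights` and `HessianStructure` below are hypothetical stand-ins that only mirror the documented string values, and the helper `normalize` is likewise hypothetical; none of this is the library's own code.

```python
from enum import Enum


# Hypothetical stand-ins mirroring the documented string values.
# The real library defines its own enums, which may contain other members.
class SubsetOfWeights(str, Enum):
    ALL = "all"
    LAST_LAYER = "last_layer"
    SUBNETWORK = "subnetwork"


class HessianStructure(str, Enum):
    DIAG = "diag"
    KRON = "kron"
    FULL = "full"
    LOWRANK = "lowrank"
    GP = "gp"


def normalize(value, enum_cls):
    """Accept an enum member or its string value and return the member (hypothetical helper)."""
    if isinstance(value, enum_cls):
        return value
    try:
        # Enum lookup by value, e.g. HessianStructure("kron") -> HessianStructure.KRON
        return enum_cls(value)
    except ValueError:
        allowed = sorted(member.value for member in enum_cls)
        raise ValueError(f"{value!r} is not one of {allowed}") from None


print(normalize("kron", HessianStructure))  # the enum member for "kron"
```

Rejecting unknown strings early with the list of allowed values gives a clearer error than failing later during dispatch.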

Returns

laplace : BaseLaplace
instance of the chosen BaseLaplace subclass, constructed with model, likelihood, and any additional positional (*args) and keyword (**kwargs) arguments
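The string-based dispatch described above can be sketched with a small registry keyed on the (subset_of_weights, hessian_structure) pair. This is a minimal sketch, not the library's implementation: `make_laplace`, `register`, `_REGISTRY`, and the stand-in classes are hypothetical, only plain strings are handled, and the real subclasses take further constructor arguments.

```python
# Hypothetical registry mapping (subset_of_weights, hessian_structure) -> subclass.
_REGISTRY = {}


def register(subset_of_weights, hessian_structure):
    """Class decorator recording which combination a subclass handles (hypothetical)."""
    def decorator(cls):
        _REGISTRY[(subset_of_weights, hessian_structure)] = cls
        return cls
    return decorator


class BaseLaplace:
    """Stand-in base class that just stores what it was constructed with."""

    def __init__(self, model, likelihood, *args, **kwargs):
        self.model = model
        self.likelihood = likelihood
        self.extra_args = args
        self.extra_kwargs = kwargs


@register("last_layer", "kron")
class KronLLLaplace(BaseLaplace):
    pass


@register("last_layer", "diag")
class DiagLLLaplace(BaseLaplace):
    pass


@register("all", "full")
class FullLaplace(BaseLaplace):
    pass


def make_laplace(model, likelihood, subset_of_weights="last_layer",
                 hessian_structure="kron", *args, **kwargs):
    """Look up the subclass for the requested combination and instantiate it."""
    key = (subset_of_weights, hessian_structure)
    try:
        cls = _REGISTRY[key]
    except KeyError:
        raise ValueError(f"no Laplace subclass registered for {key}") from None
    # Remaining positional and keyword arguments are forwarded unchanged.
    return cls(model, likelihood, *args, **kwargs)


la = make_laplace("my_model", "classification", hessian_structure="diag",
                  prior_precision=1.0)
print(type(la).__name__)  # DiagLLLaplace
```

Keeping the mapping in one table means a new approximation only has to register itself to become reachable by name, without changing the factory function.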