laplace.utils.enums

SubsetOfWeights

Bases: str, Enum

Valid options for subset_of_weights.

ALL

ALL = 'all'

All-layer, all-parameter Laplace.

LAST_LAYER

LAST_LAYER = 'last_layer'

Last-layer Laplace.

SUBNETWORK

SUBNETWORK = 'subnetwork'

Subnetwork Laplace.
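SubsetOfWeights subclasses both str and Enum, as does every enum on this page, so its members behave like ordinary strings. The minimal sketch below shows what that means in plain Python. It uses a local mirror class with the values copied from the listing above rather than importing from laplace.utils.enums, and it does not show how the library itself validates options.

```python
from enum import Enum


class SubsetOfWeights(str, Enum):
    """Local mirror of the documented enum (hypothetical; not imported from laplace)."""

    ALL = "all"
    LAST_LAYER = "last_layer"
    SUBNETWORK = "subnetwork"


# Subclassing str makes each member compare equal to its raw string value.
assert SubsetOfWeights.LAST_LAYER == "last_layer"
assert isinstance(SubsetOfWeights.ALL, str)

# Looking a member up by value returns the canonical member object.
assert SubsetOfWeights("subnetwork") is SubsetOfWeights.SUBNETWORK

# An unknown value raises ValueError, a natural hook for input validation.
try:
    SubsetOfWeights("first_layer")
except ValueError as err:
    print(f"rejected: {err}")
```

As a result, passing the string 'last_layer' and passing SubsetOfWeights.LAST_LAYER should be interchangeable wherever the option is compared as a string.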

HessianStructure

Bases: str, Enum

Valid options for hessian_structure.

FULL

FULL = 'full'

Full Hessian (generally very expensive).

KRON

KRON = 'kron'

Kronecker-factored Hessian (preferable).

DIAG

DIAG = 'diag'

Diagonal Hessian.

LOWRANK

LOWRANK = 'lowrank'

Low-rank Hessian.

GP

GP = 'gp'

Functional Laplace.
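The cost remarks above (FULL is generally very expensive, KRON is preferable) can be made concrete by counting the floats each structure stores for a single fully connected layer. This back-of-the-envelope sketch rests on three assumptions: a weight matrix of shape (n_out, n_in) with no bias, K-FAC-style factors of size n_in × n_in and n_out × n_out for KRON, and a hypothetical rank-k layout of p·k + k numbers for LOWRANK. GP is left out because its cost grows with the number of data points in the kernel matrix rather than with the parameter count p.

```python
def hessian_storage(n_in: int, n_out: int, rank: int = 10) -> dict[str, int]:
    """Number of stored floats per Hessian structure for one linear layer.

    Assumes a weight matrix of shape (n_out, n_in) without bias, so the
    layer has p = n_in * n_out parameters.
    """
    p = n_in * n_out
    return {
        "full": p * p,                        # dense p x p matrix
        "kron": n_in * n_in + n_out * n_out,  # factors A (n_in x n_in) and G (n_out x n_out)
        "diag": p,                            # one entry per parameter
        "lowrank": p * rank + rank,           # U (p x rank) plus rank eigenvalues (assumed layout)
    }


if __name__ == "__main__":
    for name, count in hessian_storage(n_in=512, n_out=256).items():
        print(f"{name:>8}: {count:,} floats")
```

For a 512-to-256 layer this gives 17,179,869,184 floats for FULL, compared with 327,680 for KRON and 131,072 for DIAG. KRON keeps within-layer correlations at a cost close to DIAG's.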

Likelihood

Bases: str, Enum

Valid options for likelihood.

REGRESSION

REGRESSION = 'regression'

Homoskedastic regression, assuming loss_fn = nn.MSELoss().

CLASSIFICATION

CLASSIFICATION = 'classification'

Classification, assuming loss_fn = nn.CrossEntropyLoss().

REWARD_MODELING

REWARD_MODELING = 'reward_modeling'

Bradley-Terry likelihood, for preference learning / reward modeling.
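The Bradley-Terry likelihood scores a preference pair using the difference of the two predicted rewards: P(chosen preferred over rejected) = \( \sigma(r_\text{chosen} - r_\text{rejected}) \), where \( \sigma \) is the logistic sigmoid. The plain-Python sketch below computes this probability and its negative log-likelihood; all function names are hypothetical. It also checks that the probability equals a two-way softmax over the pair of rewards, which is why a single preference can be fit with a classification-style cross-entropy.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bradley_terry_prob(r_chosen: float, r_rejected: float) -> float:
    """P(chosen is preferred over rejected) under the Bradley-Terry model."""
    return sigmoid(r_chosen - r_rejected)


def bradley_terry_nll(r_chosen: float, r_rejected: float) -> float:
    """Negative log-likelihood of one observed preference (chosen over rejected)."""
    return -math.log(bradley_terry_prob(r_chosen, r_rejected))


def softmax_first(a: float, b: float) -> float:
    """Probability of index 0 under a softmax over the two logits [a, b]."""
    m = max(a, b)  # subtract the max for numerical stability
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb)


if __name__ == "__main__":
    print(bradley_terry_prob(1.0, 1.0))  # equal rewards: a coin flip
    # The pairwise probability and the two-way softmax agree (up to rounding).
    print(bradley_terry_prob(2.0, 0.5), softmax_first(2.0, 0.5))
```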

PredType

Bases: str, Enum

Valid options for pred_type.

GLM

GLM = 'glm'

Linearized, closed-form predictive.

NN

NN = 'nn'

Monte-Carlo predictive on the NN's weights.

GP

GP = 'gp'

Gaussian-process predictive, done by inverting the kernel matrix.
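The difference between GLM and NN is easiest to see on a one-parameter toy model \( f(x; \theta) = \tanh(\theta x) \) with a Gaussian posterior \( \mathcal{N}(\theta_\text{MAP}, s^2) \) over \( \theta \). GLM linearizes f around \( \theta_\text{MAP} \), which gives mean \( f(x; \theta_\text{MAP}) \) and variance \( J^2 s^2 \) with \( J = \partial f / \partial \theta \). NN instead samples \( \theta \) and pushes each sample through the network. The toy model, the posterior, and the helper names are all assumptions made for illustration.

```python
import math
import random


def f(x: float, theta: float) -> float:
    """Toy one-parameter 'network': f(x; theta) = tanh(theta * x)."""
    return math.tanh(theta * x)


def glm_predictive(x: float, theta_map: float, post_var: float) -> tuple[float, float]:
    """Linearized predictive: mean f(x; theta_MAP), variance J^2 * post_var."""
    jac = x * (1.0 - math.tanh(theta_map * x) ** 2)  # df/dtheta at theta_MAP
    return f(x, theta_map), jac * jac * post_var


def nn_mc_predictive(x: float, theta_map: float, post_var: float,
                     n_samples: int = 10_000, seed: int = 0) -> tuple[float, float]:
    """Monte-Carlo predictive: sample theta from the posterior, push each sample through f."""
    rng = random.Random(seed)
    std = math.sqrt(post_var)
    outs = [f(x, rng.gauss(theta_map, std)) for _ in range(n_samples)]
    mean = sum(outs) / n_samples
    var = sum((o - mean) ** 2 for o in outs) / n_samples
    return mean, var


if __name__ == "__main__":
    for post_var in (1e-4, 1.0):
        print(post_var, glm_predictive(0.5, 1.2, post_var), nn_mc_predictive(0.5, 1.2, post_var))
```

With a small posterior variance the two predictives agree closely. As \( s^2 \) grows, the MC predictive captures the saturation of tanh, which the linearization ignores.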

LinkApprox

Bases: str, Enum

Valid options for link_approx. Only works with likelihood = Likelihood.CLASSIFICATION.

MC

MC = 'mc'

Monte-Carlo approximation in the function space on top of the GLM predictive.

PROBIT

PROBIT = 'probit'

Closed-form multiclass probit approximation.

BRIDGE

BRIDGE = 'bridge'

Closed-form Laplace Bridge approximation.

BRIDGE_NORM

BRIDGE_NORM = 'bridge_norm'

Closed-form Laplace Bridge approximation with normalization factor. Preferable to BRIDGE.
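PROBIT and MC both start from a Gaussian over the logits with mean f_mu and variance f_var. The sketch below assumes a diagonal logit covariance and does not handle full covariances. It implements the standard multiclass probit approximation, \( \mathrm{softmax}_k\big(\mu_k / \sqrt{1 + \tfrac{\pi}{8} \sigma_k^2}\big) \), alongside a Monte-Carlo estimate of \( \mathbb{E}[\mathrm{softmax}(f)] \). The Laplace Bridge variants are not reproduced, and the function names are hypothetical.

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probit_approx(f_mu: list[float], f_var: list[float]) -> list[float]:
    """Multiclass probit: shrink each logit mean by 1/sqrt(1 + pi/8 * var), then softmax."""
    kappa = [1.0 / math.sqrt(1.0 + math.pi / 8.0 * v) for v in f_var]
    return softmax([k * mu for k, mu in zip(kappa, f_mu)])


def mc_link(f_mu: list[float], f_var: list[float],
            n_samples: int = 20_000, seed: int = 0) -> list[float]:
    """MC link: sample logits from N(f_mu, diag(f_var)), softmax each sample, average."""
    rng = random.Random(seed)
    probs = [0.0] * len(f_mu)
    for _ in range(n_samples):
        sample = [rng.gauss(mu, math.sqrt(v)) for mu, v in zip(f_mu, f_var)]
        for k, p in enumerate(softmax(sample)):
            probs[k] += p / n_samples
    return probs


if __name__ == "__main__":
    f_mu, f_var = [2.0, 0.0, -1.0], [1.0, 1.0, 1.0]
    print("plug-in:", softmax(f_mu))
    print("probit: ", probit_approx(f_mu, f_var))
    print("mc:     ", mc_link(f_mu, f_var))
```

Both estimates assign less confidence to the top class than the plug-in softmax(f_mu), which is the intended effect of integrating over logit uncertainty. The probit version needs one softmax, while MC needs one per sample.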

TuningMethod

Bases: str, Enum

Valid options for the method parameter in optimize_prior_precision.

MARGLIK

MARGLIK = 'marglik'

Marginal-likelihood loss via SGD. Does not require validation data.

GRIDSEARCH

GRIDSEARCH = 'gridsearch'

Grid search. Requires validation data.
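Both tuning methods can be illustrated on a toy conjugate model where everything is closed-form: a single weight w with \( y = w x + \varepsilon \), known noise variance SIGMA2, and prior precision \( \tau \). This is a sketch, not the library's optimize_prior_precision, and the model, constants, and function names are assumptions. The MARGLIK-style routine uses only training data and runs plain full-batch gradient ascent on \( \log \tau \) (standing in for an SGD-style optimizer) over the Laplace log marginal likelihood. The GRIDSEARCH-style routine scores each candidate \( \tau \) by predictive NLL on held-out data.

```python
import math

# Toy conjugate model (assumed for illustration): y = w * x + noise,
# noise ~ N(0, SIGMA2) with SIGMA2 known, prior w ~ N(0, 1 / tau).
SIGMA2 = 0.25


def posterior(xs, ys, tau):
    """Posterior precision and mean of the single weight w (exact for this model)."""
    h = sum(x * x for x in xs) / SIGMA2  # Hessian of the negative log-likelihood
    prec = tau + h
    mean = sum(x * y for x, y in zip(xs, ys)) / SIGMA2 / prec
    return prec, mean


def log_marglik(xs, ys, tau):
    """Laplace log marginal likelihood: log p(y | m) - 0.5 * (log(prec / tau) + tau * m**2)."""
    prec, m = posterior(xs, ys, tau)
    log_lik = sum(
        -0.5 * math.log(2 * math.pi * SIGMA2) - 0.5 * (y - m * x) ** 2 / SIGMA2
        for x, y in zip(xs, ys)
    )
    return log_lik - 0.5 * (math.log(prec / tau) + tau * m * m)


def tune_marglik(xs, ys, log_tau=0.0, lr=0.1, steps=500):
    """MARGLIK-style: gradient ascent on log(tau) using training data only."""
    h = sum(x * x for x in xs) / SIGMA2
    for _ in range(steps):
        tau = math.exp(log_tau)
        _, m = posterior(xs, ys, tau)
        # d/dtau of log_marglik is 0.5/tau - 0.5*m**2 - 0.5/(tau + h); m is the MAP,
        # so its own dependence on tau drops out. Multiply by tau for d/dlog(tau).
        grad = (0.5 / tau - 0.5 * m * m - 0.5 / (tau + h)) * tau
        log_tau += lr * grad
    return math.exp(log_tau)


def tune_gridsearch(xs, ys, val_xs, val_ys, grid):
    """GRIDSEARCH-style: pick the tau with the lowest predictive NLL on validation data."""

    def val_nll(tau):
        prec, m = posterior(xs, ys, tau)
        nll = 0.0
        for x, y in zip(val_xs, val_ys):
            var = SIGMA2 + x * x / prec  # predictive variance of y at x
            nll += 0.5 * math.log(2 * math.pi * var) + 0.5 * (y - m * x) ** 2 / var
        return nll

    return min(grid, key=val_nll)


if __name__ == "__main__":
    xs = [i / 10 for i in range(-10, 11)]
    ys = [1.5 * x + 0.1 * math.sin(7 * x) for x in xs]
    val_xs = [0.05 + i / 10 for i in range(-10, 10)]
    val_ys = [1.5 * x + 0.1 * math.sin(7 * x) for x in val_xs]
    grid = [10 ** (k / 4) for k in range(-12, 13)]
    print("marglik tau:   ", tune_marglik(xs, ys))
    print("gridsearch tau:", tune_gridsearch(xs, ys, val_xs, val_ys, grid))
```

In this model the Laplace marginal likelihood is exact, so the gradient step converges to the evidence-maximizing \( \tau \). The grid search can only return one of the supplied candidates.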

PriorStructure

Bases: str, Enum

Valid options for the prior_structure in optimize_prior_precision.

SCALAR

SCALAR = 'scalar'

Scalar prior precision \( \tau I, \tau \in \mathbb{R} \).

DIAG

DIAG = 'diag'

Diagonal prior precision \( \operatorname{diag}(\tau), \tau \in \mathbb{R}^p \).

LAYERWISE

LAYERWISE = 'layerwise'

Layerwise prior precision, i.e. a single scalar prior precision for each block of the diagonal prior-precision matrix, where each block corresponds to one of the NN's layers.
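To see how the three structures relate, each can be expanded into the diagonal of the prior-precision matrix. SCALAR repeats one \( \tau \) for all p parameters, LAYERWISE repeats one \( \tau \) across each layer's block, and DIAG supplies all p values directly. The helper below and its explicit structure argument are hypothetical; they only make the shapes concrete.

```python
def expand_prior_precision(structure: str, prior_prec: list[float],
                           layer_sizes: list[int]) -> list[float]:
    """Return the diagonal of the prior-precision matrix, one entry per parameter.

    structure: "scalar", "diag" or "layerwise" (mirroring the PriorStructure values).
    layer_sizes: number of parameters in each layer block, in parameter order.
    """
    n_params = sum(layer_sizes)
    if structure == "scalar":
        if len(prior_prec) != 1:
            raise ValueError("scalar prior expects exactly one value")
        return [prior_prec[0]] * n_params  # tau * I
    if structure == "layerwise":
        if len(prior_prec) != len(layer_sizes):
            raise ValueError("layerwise prior expects one value per layer")
        # Repeat each layer's tau across that layer's block of the diagonal.
        return [tau for tau, n in zip(prior_prec, layer_sizes) for _ in range(n)]
    if structure == "diag":
        if len(prior_prec) != n_params:
            raise ValueError("diag prior expects one value per parameter")
        return list(prior_prec)  # tau in R^p, used as-is
    raise ValueError(f"unknown prior structure: {structure!r}")


if __name__ == "__main__":
    sizes = [3, 2]  # e.g. a 3-parameter layer followed by a 2-parameter layer
    print(expand_prior_precision("scalar", [1.0], sizes))
    print(expand_prior_precision("layerwise", [1.0, 10.0], sizes))
    print(expand_prior_precision("diag", [0.5, 1.0, 1.5, 2.0, 2.5], sizes))
```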