# laplace.baselaplace

## BaseLaplace

```python
BaseLaplace(model: Module, likelihood: Likelihood | str, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
```
Base class for all Laplace approximations in this library.
Parameters:

- `model` (Module)
- `likelihood` (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) – determines the log likelihood Hessian approximation. In the case of 'reward_modeling', it fits Laplace using the classification likelihood, then does prediction as in the regression likelihood. The model needs to be defined accordingly: the forward pass during training takes `x.shape == (batch_size, 2, dim)` with `y.shape == (batch_size,)`, while during evaluation `x.shape == (batch_size, dim)`. Note that 'reward_modeling' only supports `KronLaplace` and `DiagLaplace`.
- `sigma_noise` (Tensor or float, default: 1) – observation noise for the regression setting; must be 1 for classification.
- `prior_precision` (Tensor or float, default: 1) – prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case.
- `prior_mean` (Tensor or float, default: 0) – prior mean of a Gaussian prior, useful for continual learning.
- `temperature` (float, default: 1) – temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa.
- `enable_backprop` (bool, default: False) – whether to enable backprop to the input `x` through the Laplace predictive. Useful, e.g., for Bayesian optimization.
- `dict_key_x` (str, default: 'input_ids') – the dictionary key under which the input tensor `x` is stored. Only has an effect when the model takes a `MutableMapping` as input. Useful for Huggingface LLM models.
- `dict_key_y` (str, default: 'labels') – the dictionary key under which the target tensor `y` is stored. Only has an effect when the model takes a `MutableMapping` as input. Useful for Huggingface LLM models.
- `backend` (subclass of `laplace.curvature.CurvatureInterface`, default: None) – backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
- `backend_kwargs` (dict, default: None) – arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
- `asdl_fisher_kwargs` (dict, default: None) – arguments passed specifically to the ASDL backend on initialization.
Source code in laplace/baselaplace.py
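For orientation, a minimal construction-and-fit sketch is shown below. It assumes the `Laplace` factory from the same library (which returns a `BaseLaplace` subclass and forwards keyword arguments to this constructor) and a `train_loader` yielding `(X, y)` batches:

```python
import torch
from laplace import Laplace

# Toy regression model; `train_loader` is assumed to exist.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
)

# Keyword arguments are forwarded to the BaseLaplace constructor above.
la = Laplace(
    model,
    likelihood="regression",
    sigma_noise=0.5,       # observation noise of the Gaussian likelihood
    prior_precision=1.0,   # isotropic Gaussian prior (= weight decay)
)
la.fit(train_loader)
```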
### log_likelihood

Compute the log likelihood on the training data after `.fit()` has been called. The log likelihood is computed on demand based on the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.

Returns:

- `log_likelihood` (Tensor)
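As a sketch of the "iterative updates" mentioned above: assuming a fitted `la` that exposes the differentiable `log_marginal_likelihood(prior_precision, sigma_noise)` method (built on this log likelihood; available on the parametric subclasses), the hyperparameters can be tuned by gradient descent post-hoc:

```python
import torch

log_prior = torch.zeros(1, requires_grad=True)   # log prior precision
log_sigma = torch.zeros(1, requires_grad=True)   # log observation noise
optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    optimizer.zero_grad()
    # Differentiable in sigma_noise via the on-demand log likelihood.
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    optimizer.step()
```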
### prior_precision_diag

Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns:

- `prior_precision_diag` (Tensor)
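All three accepted structures expand to the same per-parameter diagonal. A small illustration (the `n_params` and `n_layers` attributes are assumed here for exposition):

```python
import torch

la.prior_precision = 2.0                             # scalar
la.prior_precision = 2.0 * torch.ones(la.n_layers)   # layer-wise
la.prior_precision = 2.0 * torch.ones(la.n_params)   # full diagonal

diag = la.prior_precision_diag  # shape (n_params,) in every case
```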
### optimize_prior_precision

```python
optimize_prior_precision(pred_type: PredType | str, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.DIAG, val_loader: DataLoader | None = None, loss: Metric | Callable[[Tensor], Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) -> None
```
Optimize the prior precision post-hoc using the method
specified by the user.
Parameters:

- `pred_type` (PredType or str in {'glm', 'nn'}) – type of posterior predictive: linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
- `method` (TuningMethod or str in {'marglik', 'gridsearch'}, default: TuningMethod.MARGLIK) – specifies how the prior precision should be optimized.
- `n_steps` (int, default: 100) – the number of gradient descent steps to take.
- `lr` (float, default: 0.1) – the learning rate to use for gradient descent.
- `init_prior_prec` (float or Tensor, default: 1.0) – initial prior precision before the first optimization step.
- `prior_structure` (PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default: PriorStructure.DIAG) – if `init_prior_prec` is scalar, the prior precision is optimized with this structure. Otherwise, the structure of `init_prior_prec` is maintained.
- `val_loader` (DataLoader, default: None) – DataLoader for the validation set; each iterate is a training batch (X, y).
- `loss` (callable or Metric, default: None) – loss function to use for CV (cross-validation). If callable, the loss is computed offline (memory intensive). If a torchmetrics.Metric, a running loss is computed (memory efficient). The default depends on the likelihood: `RunningNLLMetric()` for classification and reward modeling, a running `MeanSquaredError()` for regression.
- `log_prior_prec_min` (float, default: -4) – lower bound of the gridsearch interval.
- `log_prior_prec_max` (float, default: 4) – upper bound of the gridsearch interval.
- `grid_size` (int, default: 100) – number of values to consider inside the gridsearch interval.
- `link_approx` (LinkApprox or str in {'mc', 'probit', 'bridge'}, default: LinkApprox.PROBIT) – how to approximate the classification link function for the `'glm'` predictive. For `pred_type='nn'`, only `'mc'` is possible.
- `n_samples` (int, default: 100) – number of samples for `link_approx='mc'`.
- `verbose` (bool, default: False) – if True, the optimized prior precision will be printed (can be a large tensor if the prior has a diagonal covariance).
- `progress_bar` (bool, default: False) – whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when `subset_of_weights='all'`.
Source code in laplace/baselaplace.py
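Both tuning methods might be invoked as follows, a sketch on a fitted `la`; `val_loader` is an assumed validation DataLoader:

```python
# Marginal-likelihood maximization (no validation data needed).
la.optimize_prior_precision(
    pred_type="glm",
    method="marglik",
    n_steps=100,
    lr=0.1,
    prior_structure="diag",
)

# Gridsearch over the log prior precision, scored on held-out data.
la.optimize_prior_precision(
    pred_type="glm",
    method="gridsearch",
    val_loader=val_loader,
    log_prior_prec_min=-4,
    log_prior_prec_max=4,
    grid_size=100,
)
```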
### _glm_forward_call

```python
_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]
```
Compute the posterior predictive on input data x
for "glm" pred type.
Parameters:

- `x` (Tensor or MutableMapping) – `(batch_size, input_shape)` if tensor. If MutableMapping, must contain said tensor.
- `likelihood` (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) – determines the log likelihood Hessian approximation.
- `link_approx` (LinkApprox or str in {'mc', 'probit', 'bridge', 'bridge_norm'}, default: LinkApprox.PROBIT) – how to approximate the classification link function for the `'glm'` predictive. For `pred_type='nn'`, only `'mc'` is possible.
- `joint` (bool, default: False) – whether to output a joint predictive distribution in regression with `pred_type='glm'`. If True, the predictive distribution has the same form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]). If False, only the marginal predictive distribution is output. Only available for regression and the GLM predictive.
- `n_samples` (int, default: 100) – number of samples for `link_approx='mc'`.
- `diagonal_output` (bool, default: False) – whether to use a diagonalized posterior predictive on the outputs. Only works for `pred_type='glm'` and `link_approx='mc'`.
Returns:

- `predictive` (Tensor or tuple[Tensor]) – For `likelihood='classification'`, a torch.Tensor is returned with a distribution over classes (similar to a softmax). For `likelihood='regression'`, a tuple of torch.Tensor is returned with the mean and the predictive variance. For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
Source code in laplace/baselaplace.py
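This method backs the public predictive call; in practice one would go through `__call__` rather than invoking it directly. A sketch, assuming a fitted `la` and an input batch `X`:

```python
# Classification: class probabilities via the linearized GLM predictive
# with the probit link approximation.
probs = la(X, pred_type="glm", link_approx="probit")

# Regression: mean and marginal predictive variance.
f_mu, f_var = la(X, pred_type="glm")
```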
### _glm_predictive_samples

```python
_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor
```
Sample from the posterior predictive on input data x
using "glm" prediction
type.
Parameters:

- `f_mu` (Tensor or MutableMapping) – GLM predictive mean `(batch_size, output_shape)`
- `f_var` (Tensor or MutableMapping) – GLM predictive covariances `(batch_size, output_shape, output_shape)`
- `n_samples` (int) – number of samples
- `diagonal_output` (bool, default: False) – whether to use a diagonalized GLM posterior predictive on the outputs
- `generator` (Generator, default: None) – random number generator to control the samples (if sampling is used)

Returns:

- `samples` (Tensor) – samples `(n_samples, batch_size, output_shape)`
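As with the forward call above, sampling is usually reached through the public `predictive_samples` method rather than this private helper. A sketch, assuming it forwards the `generator` keyword to control the draws:

```python
import torch

g = torch.Generator().manual_seed(0)  # reproducible predictive samples
samples = la.predictive_samples(X, pred_type="glm", n_samples=100, generator=g)
# samples has shape (n_samples, batch_size, output_shape)
```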