laplace.baselaplace
#
FunctionalLaplace
#
FunctionalLaplace(model: Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | Tensor = 1.0, prior_precision: float | Tensor = 1.0, prior_mean: float | Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x='input_ids', dict_key_y='labels', backend: type[CurvatureInterface] | None = BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)
Bases: BaseLaplace
Applying the GGN (Generalized Gauss-Newton) approximation for the Hessian in the Laplace approximation of the posterior turns the underlying probabilistic model from a BNN into a GLM (generalized linear model). This GLM (in the weight space) is equivalent to a GP (in the function space); see Approximate Inference Turns Deep Networks into Gaussian Processes (Khan et al., 2019).
This class implements the (approximate) GP inference through which we obtain the desired quantities (posterior predictive, marginal log-likelihood). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021) for more details.
Note that for likelihood='classification', we approximate \( L_{NN} \) with a diagonal matrix (\( L_{NN} \) is a block-diagonal matrix, where blocks represent Hessians of the per-data-point log-likelihood w.r.t. the neural network output \( f \); see Appendix A.2.1 for the exact definition). We resort to such an approximation because of the (possible) errors in the Laplace approximation for multiclass GP classification in Chapter 3.5 of the R&W 2006 GP book; see the linked question for more details. Alternatively, one could resort to one-vs-one or one-vs-rest implementations for multiclass classification; however, these are not (yet) supported here.
Parameters:
- num_data (int) – number of data points for Subset-of-Data (SOD) approximate GP inference.
- diagonal_kernel (bool) – the GP kernel here is a product of Jacobians, which results in a \( C \times C \) matrix, where \( C \) is the output dimension. If diagonal_kernel=True, only the diagonal of the GP kernel is used. This is (somewhat) equivalent to assuming independent GPs across output channels.
Source code in laplace/baselaplace.py
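A minimal usage sketch, assuming a trained PyTorch classifier model, a training DataLoader train_loader, and a test batch X_test (all placeholders, not defined here):

```python
from laplace.baselaplace import FunctionalLaplace

# `model`, `train_loader`, and `X_test` are hypothetical stand-ins for your
# own network, training DataLoader, and test inputs.
la = FunctionalLaplace(model, likelihood="classification", n_subset=100)
la.fit(train_loader)  # approximate GP inference on a subset of M = 100 training points

probs = la(X_test, pred_type="gp", link_approx="probit")  # posterior predictive over classes
lml = la.log_marginal_likelihood()                        # GP log marginal likelihood
```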
log_likelihood
#
Compute log likelihood on the training data after .fit()
has been called.
The log likelihood is computed on demand based on the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.
Returns:
- log_likelihood (Tensor) –
prior_precision_diag
#
Obtain the diagonal prior precision \(p_0\) constructed from either a scalar, layer-wise, or diagonal prior precision.
Returns:
- prior_precision_diag (Tensor) –
log_det_ratio
#
Computes the log determinant term in the GP marginal likelihood.
For classification, we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book (note that we always use the diagonal approximation \( D \) of the Hessian of the log likelihood w.r.t. \( f \)):
log determinant term := \( \log | I + D^{1/2} K D^{1/2} | \)
For regression, we use the "standard" GP marginal likelihood:
log determinant term := \( \log | K + \sigma^2 I | \)
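As an illustration of these two formulas (not the library's internal code), the terms can be computed from a kernel matrix K, the diagonal Hessian approximation D, and the observation noise:

```python
import torch

def log_det_term_classification(K: torch.Tensor, D_diag: torch.Tensor) -> torch.Tensor:
    # log|I + D^{1/2} K D^{1/2}| with D the diagonal Hessian approximation of the log likelihood
    D_sqrt = D_diag.sqrt()
    A = torch.eye(K.shape[0], dtype=K.dtype, device=K.device) + D_sqrt[:, None] * K * D_sqrt[None, :]
    return torch.logdet(A)

def log_det_term_regression(K: torch.Tensor, sigma_noise: float) -> torch.Tensor:
    # log|K + sigma^2 I|
    return torch.logdet(K + sigma_noise**2 * torch.eye(K.shape[0], dtype=K.dtype, device=K.device))
```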
scatter
#
Compute scatter term in GP log marginal likelihood.
For classification, we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book with \( \hat{f} = f \):
scatter term := \( f K^{-1} f^{T} \)
For regression, we use the "standard" GP marginal likelihood:
scatter term := \( (y - m) K^{-1} (y - m)^T \), where \( m \) is the mean of the GP prior, which in our case corresponds to \( m := f + J (\theta - \theta_{MAP}) \)
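A sketch of the two scatter terms above, written with linear solves instead of explicit inverses (K, f, y, and m are placeholders for the corresponding quantities):

```python
import torch

def scatter_classification(K: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    # f K^{-1} f^T for a flattened function-value vector f
    return f @ torch.linalg.solve(K, f)

def scatter_regression(K: torch.Tensor, y: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    # (y - m) K^{-1} (y - m)^T, with m the GP prior mean described above
    r = y - m
    return r @ torch.linalg.solve(K, r)
```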
_glm_forward_call
#
_glm_forward_call(x: Tensor | MutableMapping, likelihood: Likelihood | str, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False) -> Tensor | tuple[Tensor, Tensor]
Compute the posterior predictive on input data x
for "glm" pred type.
Parameters:
- x (Tensor or MutableMapping) – (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.
- likelihood (Likelihood or str in {'classification', 'regression', 'reward_modeling'}) – determines the log likelihood Hessian approximation.
- link_approx ({'mc', 'probit', 'bridge', 'bridge_norm'}, default: 'probit') – how to approximate the classification link function for the 'glm' predictive. For pred_type='nn', only 'mc' is possible.
- joint (bool, default: False) – whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as the GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]). If False, only the marginal predictive distribution is returned. Only available for regression and the GLM predictive.
- n_samples (int, default: 100) – number of samples for link_approx='mc'.
- diagonal_output (bool, default: False) – whether to use a diagonalized posterior predictive on the outputs. Only works for pred_type='glm' and link_approx='mc'.
Returns:
- predictive (Tensor or tuple[Tensor]) – For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
Source code in laplace/baselaplace.py
_glm_predictive_samples
#
_glm_predictive_samples(f_mu: Tensor, f_var: Tensor, n_samples: int, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor
Sample from the posterior predictive on input data x
using "glm" prediction
type.
Parameters:
- f_mu (Tensor or MutableMapping) – GLM predictive mean (batch_size, output_shape)
- f_var (Tensor or MutableMapping) – GLM predictive covariances (batch_size, output_shape, output_shape)
- n_samples (int) – number of samples
- diagonal_output (bool, default: False) – whether to use a diagonalized GLM posterior predictive on the outputs.
- generator (Generator, default: None) – random number generator to control the samples (if sampling is used)
Returns:
- samples (Tensor) – samples (n_samples, batch_size, output_shape)
Source code in laplace/baselaplace.py
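Conceptually, GLM predictive sampling draws from a Gaussian over the network outputs and, for classification, pushes each draw through a softmax. A rough sketch of this idea (not the library's exact implementation):

```python
import torch
from torch.distributions import MultivariateNormal

def glm_samples_sketch(f_mu: torch.Tensor, f_var: torch.Tensor,
                       n_samples: int, classification: bool = True) -> torch.Tensor:
    # f_mu: (batch_size, C), f_var: (batch_size, C, C), assumed positive definite
    dist = MultivariateNormal(f_mu, covariance_matrix=f_var)
    samples = dist.sample(torch.Size([n_samples]))  # (n_samples, batch_size, C)
    return samples.softmax(dim=-1) if classification else samples
```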
_check_prior_precision
#
_check_prior_precision(prior_precision: float | Tensor)
Checks whether the given prior precision is suitable for the GP interpretation of the LLA. As such, only single-value priors, i.e., isotropic priors, are suitable.
Source code in laplace/baselaplace.py
_init_K_MM
#
Allocates memory for the kernel matrix evaluated at the subset of the training data points. If the subset is of size \( M \) and the problem has \( C \) outputs, this is a list of \( C \) tensors of shape \( (M, M) \) for the diagonal kernel and a single \( (M \times C, M \times C) \) tensor otherwise.
Source code in laplace/baselaplace.py
_init_Sigma_inv
#
Allocates memory for the Cholesky decomposition of \( K_{MM} + \Lambda_{MM}^{-1} \). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021), Equation 15, for more information.
Source code in laplace/baselaplace.py
_store_K_batch
#
Given the kernel matrix between the i-th and the j-th batch, stores it in the corresponding position in self.K_MM.
Source code in laplace/baselaplace.py
_build_L
#
_build_L(lambdas: list[Tensor])
Given a list of the Hessians of the per-batch log-likelihood w.r.t. the neural network output \( f \), returns the concatenation of these Hessians in a format suitable for the kernel used (diagonal or not).
The diagonal approximation is performed in this function. Please refer to the introduction of the class for more details.
Parameters:
- lambdas (list of torch.Tensor of shape (C, C)) – contains the Hessians of the per-batch log-likelihood w.r.t. the neural network output \( f \).
Returns:
- L (list of length C of tensors of shape (M,), or a tensor of shape (M * C,)) – contains the given Hessians in a suitable format.
Source code in laplace/baselaplace.py
_build_Sigma_inv
#
Computes the Cholesky decomposition of \( K_{MM} + \Lambda_{MM}^{-1} \). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021), Equation 15, for more information.
As the diagonal approximation is performed with \( \Lambda_{MM} \) (which is stored in self.L), the code is greatly simplified.
Source code in laplace/baselaplace.py
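Under the diagonal approximation, the decomposed quantity has a simple form. An illustrative sketch, where K_MM is the flattened (M*C, M*C) kernel and lambda_diag holds the diagonal of \( \Lambda_{MM} \) (both hypothetical inputs):

```python
import torch

def build_sigma_inv_sketch(K_MM: torch.Tensor, lambda_diag: torch.Tensor) -> torch.Tensor:
    # Cholesky factor of K_MM + Lambda_MM^{-1}, with Lambda_MM diagonal
    # (cf. Immer et al., 2021, Eq. 15)
    return torch.linalg.cholesky(K_MM + torch.diag(1.0 / lambda_diag))
```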
_get_SoD_data_loader
#
Subset-of-Datapoints data loader
Source code in laplace/baselaplace.py
fit
#
fit(train_loader: DataLoader | MutableMapping, progress_bar: bool = False)
Fit the Laplace approximation of a GP posterior.
Parameters:
- train_loader (DataLoader) – train_loader.dataset needs to be set to access \( N \), the size of the data set; train_loader.batch_size needs to be set to access \( b \), the batch size.
- progress_bar (bool, default: False) – whether to show a progress bar during the fitting process.
Source code in laplace/baselaplace.py
__call__
#
__call__(x: Tensor | MutableMapping, pred_type: PredType | str = PredType.GP, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None, fitting: bool = False, **model_kwargs: dict[str, Any]) -> Tensor | tuple[Tensor, Tensor]
Compute the posterior predictive on input data x
.
Parameters:
- x (Tensor or MutableMapping) – (batch_size, input_shape) if tensor. If MutableMapping, must contain the said tensor.
- pred_type ('gp', default: 'gp') – type of posterior predictive, the linearized GLM predictive (GP). The GP predictive is consistent with the curvature approximations used here.
- link_approx ({'mc', 'probit', 'bridge', 'bridge_norm'}, default: 'probit') – how to approximate the classification link function for the 'glm' predictive.
- joint (bool, default: False) – whether to output a joint predictive distribution in regression with pred_type='glm'. If set to True, the predictive distribution has the same form as the GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]). If False, only the marginal predictive distribution is returned. Only available for regression and the GLM predictive.
- n_samples (int, default: 100) – number of samples for link_approx='mc'.
- diagonal_output (bool, default: False) – whether to use a diagonalized posterior predictive on the outputs. Only works for link_approx='mc'.
- generator (Generator, default: None) – random number generator to control the samples (if sampling is used).
- fitting (bool, default: False) – whether or not this predictive call is done during fitting. Only useful for reward modeling: the likelihood is set to "regression" when False and "classification" when True.
Returns:
- predictive (Tensor or tuple[Tensor]) – For likelihood='classification', a torch.Tensor is returned with a distribution over classes (similar to a softmax). For likelihood='regression', a tuple of torch.Tensor is returned with the mean and the predictive variance. For likelihood='regression' and joint=True, a tuple of torch.Tensor is returned with the mean and the predictive covariance.
Source code in laplace/baselaplace.py
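For example, assuming a fitted la and a hypothetical test batch X_test:

```python
# Classification: class probabilities via the probit link approximation.
probs = la(X_test, pred_type="gp", link_approx="probit")

# Regression with a joint predictive: mean and full covariance over the batch.
f_mu, f_cov = la(X_test, pred_type="gp", joint=True)
```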
predictive_samples
#
predictive_samples(x: Tensor | MutableMapping[str, Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: Generator | None = None) -> Tensor
Sample from the posterior predictive on input data x
.
Can be used, for example, for Thompson sampling.
Parameters:
- x (Tensor or MutableMapping) – input data (batch_size, input_shape)
- pred_type ('glm', default: 'glm') – type of posterior predictive, the linearized GLM predictive.
- n_samples (int, default: 100) – number of samples
- diagonal_output (bool, default: False) – whether to use a diagonalized GLM posterior predictive on the outputs. Only applies when pred_type='glm'.
- generator (Generator, default: None) – random number generator to control the samples (if sampling is used)
Returns:
- samples (Tensor) – samples (n_samples, batch_size, output_shape)
Source code in laplace/baselaplace.py
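For instance, one round of Thompson sampling over a batch of candidate inputs might look as follows, assuming a fitted la with a scalar-output (e.g. reward/regression) model and a hypothetical candidate batch X_candidates:

```python
# One posterior function draw per candidate; pick the candidate whose sampled
# output is largest. samples has shape (1, n_candidates, 1).
samples = la.predictive_samples(X_candidates, pred_type="glm", n_samples=1)
best_candidate = samples.squeeze(0).squeeze(-1).argmax()
```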
functional_variance
#
GP posterior variance: \( k_{**} - K_{*M} (K_{MM} + \Lambda_{MM}^{-1})^{-1} K_{M*} \)
Parameters:
- Js_star (torch.Tensor of shape (N*, C, P)) – Jacobians of test data points
Returns:
- f_var (torch.Tensor of shape (N*, C, C)) – contains the posterior variances of the N* testing points.
Source code in laplace/baselaplace.py
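A sketch of this computation from pre-computed kernel blocks, treating the outputs as flattened (k_star_star, K_star_M, K_MM, and lambda_diag are hypothetical inputs; the library itself works block-wise per output):

```python
import torch

def functional_variance_sketch(k_star_star: torch.Tensor, K_star_M: torch.Tensor,
                               K_MM: torch.Tensor, lambda_diag: torch.Tensor) -> torch.Tensor:
    # k_** - K_*M (K_MM + Lambda_MM^{-1})^{-1} K_M*
    A = K_MM + torch.diag(1.0 / lambda_diag)
    correction = K_star_M @ torch.linalg.solve(A, K_star_M.T)
    return k_star_star - correction
```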
functional_covariance
#
GP posterior covariance, i.e. the joint counterpart of functional_variance evaluated across all test points: \( k_{**} - K_{*M} (K_{MM} + \Lambda_{MM}^{-1})^{-1} K_{M*} \)
Parameters:
- Js_star (torch.Tensor of shape (N*, C, P)) – Jacobians of test data points
Returns:
- f_var (torch.Tensor of shape (N* x C, N* x C)) – contains the posterior covariances of the N* testing points.
Source code in laplace/baselaplace.py
_build_K_star_M
#
_build_K_star_M(K_M_star: Tensor, joint: bool = False) -> Tensor
Computes \( K_{*M} (K_{MM} + \Lambda_{MM}^{-1})^{-1} K_{M*} \) given \( K_{M*} \).
Parameters:
- K_M_star (list of torch.Tensor) – contains K_{M*}. Tensors have shape (N_test, C, C), or (N_test, C) for the diagonal kernel.
- joint (bool, default: False) – whether to compute cross-covariances or not.
Returns:
- torch.Tensor of shape (N_test, N_test, C) for joint diagonal, (N_test, C) for non-joint diagonal, (N_test, N_test, C, C) for joint non-diagonal, and (N_test, C, C) for non-joint non-diagonal.
Source code in laplace/baselaplace.py
optimize_prior_precision
#
optimize_prior_precision(pred_type: PredType | str = PredType.GP, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, loss: Metric | Callable[[Tensor], Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) -> None
Calls optimize_prior_precision_base from BaseLaplace with pred_type='gp'.
Source code in laplace/baselaplace.py
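Typical usage after fitting (la is a fitted FunctionalLaplace instance):

```python
# Tune the scalar prior precision by maximizing the GP marginal likelihood.
la.optimize_prior_precision(pred_type="gp", method="marglik", n_steps=100, lr=0.1)
```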
_kernel_batch
#
Compute K_bb, which is part of K_MM kernel matrix.
Parameters:
- jacobians (Tensor of shape (b, C, P)) –
- batch (Tensor of shape (b, C)) –
Returns:
- kernel (Tensor) – K_bb with shape (b * C, b * C)
Source code in laplace/baselaplace.py
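Since the GP kernel is a product of Jacobians scaled by the prior, a minimal sketch of such a kernel block, assuming a scalar (isotropic) prior precision prior_prec, looks like this (not the library's exact code):

```python
import torch

def kernel_batch_sketch(jacobians: torch.Tensor, prior_prec: float) -> torch.Tensor:
    # jacobians: (b, C, P). Kernel block K_bb = J Sigma_0 J^T with Sigma_0 = I / prior_prec,
    # returned in flattened form with shape (b * C, b * C).
    b, C, P = jacobians.shape
    J = jacobians.reshape(b * C, P)
    return (J @ J.T) / prior_prec
```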
_kernel_star
#
_kernel_star(jacobians: Tensor, joint: bool = False) -> Tensor
Compute K_star_star kernel matrix.
Parameters:
- jacobians (Tensor of shape (b, C, P)) –
Returns:
- kernel (Tensor) – K_star with shape (b, C, C)
Source code in laplace/baselaplace.py
_kernel_batch_star
#
Compute K_b_star, which is a part of K_M_star kernel matrix.
Parameters:
- jacobians (Tensor of shape (b1, C, P)) –
- batch (Tensor of shape (b2, C)) –
Returns:
- kernel (Tensor) – K_batch_star with shape (b1, b2, C, C)
Source code in laplace/baselaplace.py
_jacobians
#
A wrapper function to compute Jacobians. This enables reusing the same kernel methods (kernel_batch, etc.) in FunctionalLaplace and FunctionalLLLaplace by simply overriding this method instead of all the kernel methods.
Source code in laplace/baselaplace.py
_mean_scatter_term_batch
#
Compute the mean vector in the scatter term of the log marginal likelihood.
See the scatter property above for the exact equations of the mean vectors in the scatter terms for both types of likelihood (regression, classification).
Parameters:
- Js (Tensor) – Jacobians (batch, output_shape, parameters)
- f (Tensor) – NN output (batch, output_shape)
- y (Tensor) – data labels (batch, output_shape)
Returns:
- mu (Tensor) – mean vector of the scatter term, with shape (batch, output_shape)
Source code in laplace/baselaplace.py
log_marginal_likelihood
#
log_marginal_likelihood(prior_precision: Tensor | None = None, sigma_noise: Tensor | None = None) -> Tensor
Compute the Laplace approximation to the log marginal likelihood.
Requires that the Laplace approximation has been fit before.
The resulting torch.Tensor is differentiable in prior_precision
and
sigma_noise
if these have gradients enabled.
By passing prior_precision
or sigma_noise
, the current value is
overwritten. This is useful for iterating on the log marginal likelihood.
Parameters:
- prior_precision (Tensor, default: None) – prior precision, if it should be changed from the current prior_precision value
- sigma_noise (Tensor, default: None) – observation noise standard deviation, if it should be changed
Returns:
- log_marglik (Tensor) –
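Because the returned tensor is differentiable in both hyperparameters, a common pattern is gradient-based tuning. A sketch for a regression likelihood (the optimizer settings are illustrative, and la is a fitted instance):

```python
import torch

# Optimize log-hyperparameters so that precision and noise stay positive.
log_prior = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
```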