Module laplace.utils.feature_extractor
Classes
class FeatureReduction (value, names=None, *, module=None, qualname=None, type=None, start=1)
-
Possible choices of feature reduction before applying last-layer Laplace.
Ancestors
- builtins.str
- enum.Enum
Class variables
var PICK_FIRST
var PICK_LAST
var AVERAGE
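For intuition, the three reductions correspond to the following operations on a (batch_size, seq_len, embd_dim) feature tensor (a sketch of the semantics, not laplace's internal code):

```python
import torch

# Hypothetical feature tensor: batch of 2 sequences, 5 tokens, 4-dim embeddings.
feats = torch.randn(2, 5, 4)

# Each reduction collapses the sequence dimension (dim=1):
pick_first = feats[:, 0, :]   # first token, e.g. a [CLS]-style embedding
pick_last = feats[:, -1, :]   # last token, e.g. causal-LM style
average = feats.mean(dim=1)   # mean over all tokens

for reduced in (pick_first, pick_last, average):
    assert reduced.shape == (2, 4)  # (batch_size, embd_dim)
```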
class FeatureExtractor (model: nn.Module, last_layer_name: str | None = None, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None)
-
Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that for every forward pass the last layer is the same. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.
Parameters
model : torch.nn.Module
- PyTorch model
last_layer_name : str, default=None
- if the name of the last layer is already known, otherwise it will be determined automatically.
enable_backprop : bool, default=False
- whether to enable backprop through the feature extractor to get the gradients of the inputs. Useful for e.g. Bayesian optimization.
feature_reduction : FeatureReduction or str, default=None
- when the last-layer features is a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When passing a string, the available options are {'pick_first', 'pick_last', 'average'}.
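The wrapper idea can be sketched with a plain forward hook on a toy model (a simplified illustration of the mechanism, not laplace's actual implementation; the model and variable names are hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical toy classifier whose last layer is linear.
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 3),  # last layer: maps 32 features to 3 classes
)

features = {}

def hook(module, inputs, output):
    # The *input* to the last linear layer is the penultimate-layer output.
    features["penultimate"] = inputs[0].detach()

# Attach the hook to the last layer; each forward pass now also records features.
model[2].register_forward_hook(hook)

x = torch.randn(4, 10)
out = model(x)
assert out.shape == (4, 3)                         # last-layer output
assert features["penultimate"].shape == (4, 32)    # penultimate features
```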
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
var call_super_init : bool
Methods
def forward(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> torch.Tensor
-
Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time.
Parameters
x : torch.Tensor or a dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
def forward_with_features(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> tuple[torch.Tensor, torch.Tensor]
-
Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time.
Parameters
x : torch.Tensor or a dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
def set_last_layer(self, last_layer_name: str) ‑> None
-
Set the last layer of the model by its name. This sets the forward hook to get the output of the penultimate layer.
Parameters
last_layer_name : str
- the name of the last layer (fixed in model.named_modules()).
def find_last_layer(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any]) ‑> torch.Tensor
-
Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of torch.nn.Linear. Might not work with every architecture, but is tested with all PyTorch torchvision classification models (besides SqueezeNet, which has no linear last layer).
Parameters
x : torch.Tensor or a dict-like object containing the input tensors
- one batch of data to use as input for the forward pass
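A simplified version of this search can be written as a static scan over named_modules() for the last nn.Linear (a sketch under the same assumptions the docstring states; the library instead determines the layer by observing one forward pass):

```python
import torch.nn as nn

def find_last_linear_name(model: nn.Module) -> str:
    # Static scan: remember the name of the last nn.Linear encountered in
    # definition order (the method documented above observes a forward pass
    # instead, which also handles models whose definition order differs
    # from execution order).
    last_name = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            last_name = name
    if last_name is None:
        raise ValueError("Model has no linear layer.")
    return last_name

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
# nn.Sequential names its children "0", "1", "2", ...
assert find_last_linear_name(model) == "2"
```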