# Example: Huggingface LLMs
In this example, we will see how to apply Laplace on a GPT2 Huggingface (HF) model.
Laplace only has lightweight requirements for this: the model's `forward` method must
take a single dict-like object (`dict`, `UserDict`, or, in general,
`collections.abc.MutableMapping`). This is entirely compatible with HF, since HF's
data loaders are assumed to emit an object derived from `UserDict`. However, you
need to ensure this yourself by wrapping the standard HF model so that it conforms
to that requirement. You also need to move the input tensors to the correct device
(e.g. via `.to(device)`) inside that `forward` method.
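In other words, the wrapper we will write below must satisfy roughly the following contract (just a sketch; the actual wrapper is implemented further down):

```python
# Sketch of the contract Laplace expects: a single dict-like batch in,
# a tensor of model outputs (e.g. logits) out.
class WrappedModel(nn.Module):
    def forward(self, batch: MutableMapping) -> torch.Tensor:
        # Move batch["input_ids"] etc. to the model's device here,
        # then return outputs of shape (batch_size, n_outputs).
        ...
```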
Let's start, as usual, by importing the required packages.
```python
from collections.abc import MutableMapping
from collections import UserDict

import numpy
import torch
from torch import nn
import torch.utils.data as data_utils

from laplace import Laplace

import logging
import warnings

logging.basicConfig(level="ERROR")
warnings.filterwarnings("ignore")

from transformers import (  # noqa: E402
    GPT2Config,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizer,
)
from peft import LoraConfig, get_peft_model  # noqa: E402
from datasets import Dataset  # noqa: E402

# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)
```
Next, we create a toy dataset. You can of course use any HF dataset or your own.
```python
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

data = [
    {"text": "Today is hot, but I will manage!!!!", "label": 1},
    {"text": "Tomorrow is cold", "label": 0},
    {"text": "Carpe diem", "label": 1},
    {"text": "Tempus fugit", "label": 1},
]
dataset = Dataset.from_list(data)


def tokenize(row):
    return tokenizer(row["text"])


dataset = dataset.map(tokenize, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
dataloader = data_utils.DataLoader(
    dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer)
)

data = next(iter(dataloader))
print(
    f"Huggingface data defaults to UserDict, which is a MutableMapping? {isinstance(data, UserDict)}"
)
for k, v in data.items():
    print(k, v.shape)
```
This is the output:

```
Huggingface data defaults to UserDict, which is a MutableMapping? True
input_ids torch.Size([4, 9])
attention_mask torch.Size([4, 9])
labels torch.Size([4])
```
## Laplace on a subset of an LLM's weights
Now, let's get to the main "meat" of this example: wrapping the HF model into a model that is
compatible with Laplace. Notice that this wrapper just wraps the HF model and does nothing else.
Notice also that we move the inputs to the correct device via `.to(device)` inside `self.forward()`.
```python
class MyGPT2(nn.Module):
    """
    Huggingface LLM wrapper.

    Args:
        tokenizer: The tokenizer used for preprocessing the text data. Needed
            since the model needs to know the padding token id.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        """
        Custom forward function. Handles things like moving the
        input tensor to the correct device inside.

        Args:
            data: A dict-like data structure with `input_ids` inside.
                This is the default data structure assumed by Huggingface
                dataloaders.

        Returns:
            logits: An `(batch_size, n_classes)`-sized tensor of logits.
        """
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        attn_mask = data["attention_mask"].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
        return output_dict.logits


model = MyGPT2(tokenizer)
```
Now, let's apply Laplace. We start with a last-layer Laplace. Notice that we add a
`feature_reduction` argument. This is because Huggingface models reduce the
logits and not the features.
```python
model = MyGPT2(tokenizer)
model.eval()

la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="last_layer",
    hessian_structure="full",
    # This must reflect faithfully the reduction technique used in the model
    # Otherwise, correctness is not guaranteed
    feature_reduction="pick_last",
)
la.fit(dataloader)
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f"[Last-layer Laplace] The predictive tensor is of shape: {la(X_test).shape}.")
```
The printed output shows a predictive tensor of shape `torch.Size([4, 2])`, i.e. `(batch_size, n_classes)`.
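For classification, the returned predictive is a tensor of (approximate) class probabilities, so you can derive a simple uncertainty estimate directly from it. A minimal sketch (the entropy computation below is plain PyTorch, not part of the library):

```python
# The Laplace predictive for classification contains class probabilities,
# so the predictive entropy is a simple per-example uncertainty measure.
probs = la(X_test)  # shape: (batch_size, n_classes)
entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)
print(entropy)  # higher values = more uncertain predictions
```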
## Subnetwork Laplace
Alternatively, we can achieve the same thing by switching off the gradients of all layers except the
top one. Laplace will automatically compute the Hessian (and Jacobians) only for the
parameters whose `requires_grad` is `True`.
Notice that you can "mix and match" this gradient switching: by enabling gradients on any subset of
parameters, you easily get a subnetwork Laplace (see the sketch after the outputs below).
```python
model.eval()

# Enable grad only for the last layer
for p in model.hf_model.parameters():
    p.requires_grad = False

for p in model.hf_model.score.parameters():
    p.requires_grad = True

la = Laplace(
    model,
    # Will only hit the last-layer since it's the only one that is grad-enabled
    likelihood="classification",
    subset_of_weights="all",
    hessian_structure="diag",
)
la.fit(dataloader)
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f"[Subnetwork Laplace] The predictive tensor is of shape: {la(X_test).shape}.")
```
The printed output again shows a `torch.Size([4, 2])` predictive tensor, validating that Laplace works on this grad-based subnetwork as well.
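As a sketch of the "mix and match" idea, you could, for instance, also enable gradients on the last GPT2 transformer block in addition to the classification head. The attribute names below (`transformer.h`, `score`) follow Huggingface's `GPT2ForSequenceClassification` and are specific to this sketch, not requirements of Laplace:

```python
# Sketch: treat the last transformer block *and* the classification head
# as the subnetwork, keeping everything else fixed.
for p in model.hf_model.parameters():
    p.requires_grad = False
for p in model.hf_model.transformer.h[-1].parameters():  # last GPT2 block
    p.requires_grad = True
for p in model.hf_model.score.parameters():  # classification head
    p.requires_grad = True

la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="all",  # "all" here means all grad-enabled parameters
    hessian_structure="diag",
)
la.fit(dataloader)
```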
## Full Laplace on LoRA parameters only
Of course, you can also apply Laplace only on parameter-efficient fine-tuning weights (like LoRA).
To do this, simply extend your LLM with LoRA using HF's `peft` library and apply Laplace as
usual. Note that `peft` automatically switches off the gradients of the non-LoRA weights.
```python
def get_lora_model():
    model = MyGPT2(tokenizer)  # Note we don't disable grad
    config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["c_attn"],  # LoRA on the attention weights
        lora_dropout=0.1,
        bias="none",
    )
    lora_model = get_peft_model(model, config)
    return lora_model


lora_model = get_lora_model()

# Train it as usual

lora_model.eval()

lora_la = Laplace(
    lora_model,
    likelihood="classification",
    subset_of_weights="all",
    hessian_structure="kron",
)
lora_la.fit(dataloader)

X_test = next(iter(dataloader))
print(f"[LoRA-LLM] The predictive tensor is of shape: {lora_la(X_test).shape}.")
```
The output is, as expected, again a predictive tensor of shape `torch.Size([4, 2])`.
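To double-check that only the LoRA weights are grad-enabled (and hence the only parameters Laplace treats stochastically), you can inspect the `peft` model; `print_trainable_parameters` is part of `peft`'s `PeftModel` API:

```python
# Only the LoRA adapter weights should be listed as trainable;
# Laplace will then only compute the Hessian over these parameters.
lora_model.print_trainable_parameters()
```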
As a final note, the dict-like input requirement of Laplace is very flexible and essentially applicable to any task and any model. You just need to wrap the model and make sure that your data loaders emit dict-like objects whose values are the input tensors.
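For instance, for a non-Huggingface dataset you could emit dict batches yourself with a custom collate function. This is a minimal sketch (the tensors, keys, and collate function are made up for illustration); note that, as in the Huggingface batches above, the targets are stored under the `"labels"` key:

```python
# Sketch: a generic dataloader that emits dict-like batches,
# which is all that Laplace's dict-input mode requires.
import torch
import torch.utils.data as data_utils

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
dataset = data_utils.TensorDataset(x, y)


def dict_collate(batch):
    xs, ys = zip(*batch)
    # Any dict-like object works; the wrapped model's forward reads the keys it needs,
    # and the targets live under "labels".
    return {"features": torch.stack(xs), "labels": torch.stack(ys)}


dataloader = data_utils.DataLoader(dataset, batch_size=8, collate_fn=dict_collate)
```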
## Caveats
Currently, diagonal EF with the Curvlinops backend is unsupported for dict-based inputs.
This is because we use `torch.func`'s `vmap` to compute the diag-EF, and it only accepts
tensor inputs in the model's `forward`. See this issue.
So, if you can write your Huggingface model's `forward` so that it accepts only a single tensor,
this is much preferable. For instance, in the case of causal LLMs like GPTs, only the `input_ids`
tensor is necessary (see the sketch below). Then, any backend and any Hessian factorization can be used.
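A minimal sketch of such a tensor-only wrapper for a causal LM (using `GPT2LMHeadModel` here purely as an illustration):

```python
import torch
from torch import nn
from transformers import GPT2LMHeadModel


class MyCausalGPT2(nn.Module):
    """Tensor-only wrapper: forward takes a single `input_ids` tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.hf_model = GPT2LMHeadModel.from_pretrained("gpt2")

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        # Only `input_ids` is needed for a causal LM; no dict-like batch required.
        output = self.hf_model(input_ids=input_ids.to(device))
        return output.logits
```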
Otherwise, if you must use dict-based inputs, choose the following backends:

- `CurvlinopsGGN` for `hessian_factorization = {"kron", "diag"}`
- `CurvlinopsEF` for `hessian_factorization = {"kron"}`
- `AsdlGGN` for `hessian_factorization = {"kron", "diag"}`
- `AsdlEF` for `hessian_factorization = {"kron", "diag"}`