Tokenizer, Dataset, and “collate_fn”
TL;DR: Return raw text from the Dataset, pass it to the collate_fn function, and leave all tokenization to collate_fn. Don’t run your tokenizer in Dataset.
1 Three places to do tokenization
When you’re training a language model with PyTorch, you have to first tokenize texts before feeding them to the model. Generally, tokenization can be done in three places:
1.1 Place 1: Before creating the Dataset
Tokenize the texts once, then use the results to build the Dataset. This is only feasible when the input corpus is small, e.g., a few hundred samples.
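As a rough sketch of this approach (assuming a Hugging Face tokenizer; the texts and targets below are placeholders), you can tokenize the whole corpus up front and wrap the resulting tensors in a TensorDataset:
import torch
from torch.utils.data import TensorDataset
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# tokenize the whole (small) corpus once, padding to the longest text
texts = ["Yes!", "A sentence with more words.", "Come on!"]
targets = torch.tensor([1.0, 2.0, 3.0])
encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# build the Dataset directly from the pre-computed tensors
dataset = TensorDataset(encodings["input_ids"], encodings["attention_mask"], targets)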
1.2 Place 2: In the __getitem__ method of the Dataset class
class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        # tokenize one sample at a time and pad it to a fixed maximum length
        input_tokens = tokenizer(text, padding="max_length", max_length=128, return_tensors="pt")
        return input_tokens
Method 2 tokenizes every single input (self.text[index]) and pads each one to the maximum possible length (128 here). As you can imagine, this is inefficient, since you’re not processing in batches.
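To see why this hurts, here is a small sketch (assuming the same bert-base-uncased tokenizer): even a three-word sample comes back as a 128-token tensor, and almost all of it is padding.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# one short sample, padded to the fixed maximum length of 128
tokens = tokenizer("Come on!", padding="max_length", max_length=128, return_tensors="pt")
print(tokens["input_ids"].shape)       # torch.Size([1, 128])
print(tokens["attention_mask"].sum())  # only a handful of real tokens; the rest is padding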
1.3 Place 3: In the collate_fn function
class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        return text

def collate_fn(batch):
    # get the raw texts
    text = [x for x in batch]
    # tokenize the whole batch, padding to the longest sample in the batch
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    return input_tokens

# init dataloader
dataloader = DataLoader(TextDataset(...), batch_size=16, collate_fn=collate_fn)
In Method 3, instead of returning tokenized results in __getitem__, we return raw text. All tokenization happens in collate_fn, which collects items from the Dataset and assembles them into a batch for the DataLoader. When we run text = [x for x in batch] in collate_fn, we get a list of texts of length batch_size. Since tokenization runs once per batch, efficiency improves significantly.
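By contrast, tokenizing a whole batch with padding=True pads only to the longest sample in that batch, so short batches stay short. A minimal sketch (again assuming bert-base-uncased; the sentences are placeholders):
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

batch = ["Yes!", "A sentence with more words.", "Come on!"]
tokens = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
# shape is (3, length of the longest sample in this batch), far smaller than (3, 128)
print(tokens["input_ids"].shape)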
2 Full code
Here’s a minimal example to illustrate the idea. The code is based on PyTorch Lightning, but you don’t need to be an expert; our focus is data, not the model/training.
# The purpose of this file is to test where to put the tokenizer
import os
import torch
from transformers import BertTokenizer, BertModel
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
# init the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
class TextDataset(Dataset):
    def __init__(self, text, target):
        self.len = len(text)
        self.text = text
        self.target = target

    def __getitem__(self, index):
        return (self.text[index], self.target[index])

    def __len__(self):
        return self.len
def collate_fn(batch):
    # unpack batch
    text = [_[0] for _ in batch]
    target = [_[1] for _ in batch]
    # get the input tokens
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # get the target
    target = torch.tensor(target)
    return input_tokens, target
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased")
        self.linear = torch.nn.Linear(768, 1)
        self.loss = torch.nn.MSELoss()

    def forward(self, x):
        # x is the dict of input tokens produced by collate_fn
        return self.linear(self.model(**x).last_hidden_state).squeeze().sum(-1)

    def training_step(self, batch, batch_idx):
        # unpack batch
        tokens, target = batch
        # forward pass
        y = self(tokens)
        loss = self.loss(y, target)
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
def init_inputs(n):
    # build a toy corpus by repeating three samples n times
    text = ["Yes!", "A sentence with more words.", "Come on!"] * n
    target = [1.0, 2, 3] * n
    return text, target
def run():
    # hyperparameters
    batch_size = 32
    # initialize the text
    train_text, train_target = init_inputs(n=int(1e4))
    # init the dataloader
    train_data = DataLoader(
        TextDataset(train_text, train_target),
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
    )
    model = BoringModel()
    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_find_unused_parameters_true",
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,
    )
    trainer.fit(model, train_dataloaders=train_data)
if __name__ == "__main__":
    run()