Tokenizer, Dataset, and "collate_fn"
Write your own collate_fn function and leave all tokenization to collate_fn. Don't run your tokenizer in Dataset!

1 Three Places to Do Tokenization
When you're training a language model with PyTorch, you have to tokenize the texts before feeding them to the model. Generally speaking, tokenization can be done in three different places:
1.1 Place 1: Before the creation of Dataset
You tokenize the texts only once, and then use the results to build the Dataset.
This is only feasible when the number of input texts is small, e.g., a few hundred.
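For concreteness, here is a minimal sketch of Place 1 (assuming a Hugging Face tokenizer named tokenizer, as in the full code below, and a hypothetical list of raw strings called texts): every text is tokenized once up front, and the Dataset only indexes the pre-computed tensors.

# Minimal sketch of Place 1: tokenize everything once, up front.
# Assumes `tokenizer` is a Hugging Face tokenizer and `texts` is a small list of strings.
from torch.utils.data import Dataset

encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

class PreTokenizedDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return self.encodings["input_ids"].size(0)

    def __getitem__(self, index):
        # return one pre-tokenized example as a dict of tensors
        return {key: value[index] for key, value in self.encodings.items()}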
1.2 Place 2: In the __getitem__ method of the Dataset class
class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        # tokenize a single text and pad it to a fixed maximum length
        input_tokens = tokenizer(text, padding="max_length", max_length=128, return_tensors="pt")
        return input_tokens
Method 2 first does tokenization for every single input (self.text[index]), and then pads every input to the maximum possible length (128 in this case). As you can imagine, this is very inefficient: you're not processing the data in batches, and every example is padded to 128 tokens regardless of its actual length.
1.3 Place 3: In the collate_fn function
class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        return text

def collate_fn(batch):
    # `batch` is a list of raw texts whose length equals the batch size
    text = list(batch)
    # get the input tokens, padded to the longest text in the batch
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    return input_tokens

# init the dataloader (`dataset` is an instance of TextDataset)
dataloader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
In Method 3, instead of returning the tokenization result (input_tokens) from the __getitem__ method, we simply return the raw text. All the tokenization is done in the collate_fn function. collate_fn acts as a data collector: it collects the individual samples returned by Dataset and assembles them into a batch for DataLoader. The batch argument it receives in collate_fn is therefore a list of texts whose length equals the batch size. Since we tokenize batch_size inputs at a time and pad only to the longest text in each batch (padding=True) rather than to a fixed maximum length, the efficiency is significantly improved.
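To see where the gain comes from, here is a small, hypothetical check (using the same bert-base-uncased tokenizer as in the full code below): with padding=True the batch is padded only to its longest member, while padding="max_length" always produces 128 columns.

# Hypothetical comparison of dynamic vs. fixed-length padding.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = ["Yes!", "A sentence with more words."]

dynamic = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
fixed = tokenizer(batch, padding="max_length", max_length=128, return_tensors="pt")

print(dynamic["input_ids"].shape)  # e.g. torch.Size([2, 8]) -- padded to the longest text
print(fixed["input_ids"].shape)    # torch.Size([2, 128])    -- always padded to max_length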
2 Full Code
Here's a minimal piece of code to illustrate what I've said. The code is based on pytorch-lightning, but you don't need to be an expert in pytorch-lightning to understand it, since our focus is on the data part, not the model and training part.
# The purpose of this file is to test where to put the tokenizer
import os

import torch
from transformers import BertTokenizer, BertModel
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

# init the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class TextDataset(Dataset):
    def __init__(self, text, target):
        self.len = len(text)
        self.text = text
        self.target = target

    def __getitem__(self, index):
        # return the raw text and its target; tokenization happens in collate_fn
        return (self.text[index], self.target[index])

    def __len__(self):
        return self.len


def collate_fn(batch):
    # unpack batch: a list of (text, target) tuples
    text = [item[0] for item in batch]
    target = [item[1] for item in batch]
    # get the input tokens, padded to the longest text in the batch
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # get the target
    target = torch.tensor(target)
    return input_tokens, target


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased")
        self.linear = torch.nn.Linear(768, 1)
        self.loss = torch.nn.MSELoss()

    def forward(self, tokens):
        # project each token embedding to a scalar and sum over the sequence
        return self.linear(self.model(**tokens).last_hidden_state).squeeze(-1).sum(-1)

    def training_step(self, batch, batch_idx):
        # unpack batch
        tokens, target = batch
        # forward pass
        y = self(tokens)
        loss = self.loss(y, target)
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def init_inputs(n):
    # build a toy dataset of repeated texts and targets
    text = ["Yes!", "A sentence with more words.", "Come on!"] * n
    target = [1.0, 2.0, 3.0] * n
    return text, target


def run():
    # hyperparameters
    batch_size = 32
    # initialize the text
    train_text, train_target = init_inputs(n=int(1e4))
    # init the dataloader
    train_data = DataLoader(
        TextDataset(train_text, train_target),
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
    )
    model = BoringModel()
    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_find_unused_parameters_true",
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
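As a quick sanity check (a hypothetical snippet, not part of the script above), you can pull a single batch from the DataLoader and inspect what collate_fn returns before launching a full training run. It assumes the definitions above are in the same file or imported.

# Hypothetical sanity check: inspect one batch produced by collate_fn.
text, target = init_inputs(n=10)
loader = DataLoader(TextDataset(text, target), batch_size=8, collate_fn=collate_fn)

tokens, targets = next(iter(loader))
print(tokens["input_ids"].shape)       # torch.Size([8, L]), L = length of the longest text in the batch
print(tokens["attention_mask"].shape)  # same shape as input_ids
print(targets.shape)                   # torch.Size([8])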