Tokenizer, Dataset, and "collate_fn"

TL;DR
For the best efficiency, pass a custom collate_fn to your DataLoader and do all tokenization there. Don't run your tokenizer inside the Dataset!

When you train a language model with PyTorch, you have to tokenize the texts before feeding them to the model. Generally speaking, tokenization can happen in three different places:

Method 1: you tokenize all texts once, up front, and use the results to build the Dataset. This is only feasible when the input is small, e.g., a few hundred texts.
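A minimal sketch of Method 1, assuming everything is tokenized up front (the names texts and PreTokenizedDataset are illustrative, not from the original post):

python

# Method 1 (sketch): tokenize the whole corpus once, store the tensors in the Dataset
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class PreTokenizedDataset(Dataset):
    def __init__(self, texts):
        # one tokenizer call for the whole corpus; fine for a few hundred texts,
        # but memory-hungry for large corpora
        self.encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    def __getitem__(self, index):
        # return the pre-computed tensors for a single example
        return {key: tensor[index] for key, tensor in self.encodings.items()}

    def __len__(self):
        return self.encodings["input_ids"].shape[0]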

Method 2: you run the tokenizer inside the Dataset's __getitem__, one text at a time:

python

class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        # tokenize one text at a time and pad it all the way to max_length=128
        input_tokens = tokenizer(text, padding="max_length", max_length=128, return_tensors="pt")
        return input_tokens

Method 2 tokenizes every single input (self.text[index]) separately and pads each one to the maximum possible length (128 in this case). As you can imagine, this is very inefficient: the tokenizer is called once per sample instead of once per batch, and every sample carries the full 128-token padding even when it is much shorter. The quick check below illustrates the padding waste.
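A standalone illustration of that overhead (not part of the original post):

python

# per-sample padding to max_length: every example becomes 128 tokens,
# no matter how short it actually is
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

single = tokenizer("Yes!", padding="max_length", max_length=128, return_tensors="pt")
print(single["input_ids"].shape)   # torch.Size([1, 128])

# batched dynamic padding: the batch is only as long as its longest member
batch = tokenizer(["Yes!", "A sentence with more words."], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)    # e.g. torch.Size([2, 8])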

Method 3: the Dataset returns the raw text, and all tokenization happens in collate_fn:

python

class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        return text

def collate_fn(batch):
    # batch is a list of raw texts returned by __getitem__, one per sample
    text = list(batch)

    # tokenize the whole batch at once, padding only to the longest text in the batch
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    return input_tokens

# init dataloader (text_dataset is an instance of TextDataset)
dataloader = DataLoader(text_dataset, batch_size=16, collate_fn=collate_fn)

In Method 3, instead of returning the tokenization result (input_tokens) from the __getitem__ function, we simply return the raw text. All the tokenization is done in collate_fn. collate_fn acts as a data collator: it collects the individual samples produced by the Dataset and assembles them into a batch for the DataLoader. So the batch argument that collate_fn receives is a list of raw texts whose length equals the batch size. Since we now call the tokenizer once per batch instead of once per sample, and pad only to the longest text in each batch rather than a fixed maximum length, efficiency improves significantly.
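To make the collate_fn contract concrete, here is a small standalone sketch (TinyTextDataset and inspect_collate are hypothetical names used only for this demonstration) showing that batch is simply a list of whatever __getitem__ returns:

python

# minimal demonstration of what collate_fn receives
from torch.utils.data import DataLoader, Dataset

class TinyTextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __getitem__(self, index):
        return self.texts[index]

    def __len__(self):
        return len(self.texts)

def inspect_collate(batch):
    # batch is a plain Python list of raw strings, one per sample
    print(type(batch), len(batch), batch)
    return batch

loader = DataLoader(TinyTextDataset(["a", "b", "c", "d"]), batch_size=2, collate_fn=inspect_collate)
for _ in loader:
    pass
# prints: <class 'list'> 2 ['a', 'b']  then  <class 'list'> 2 ['c', 'd']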

Here's a minimal piece of code to illustrate what I've said. The code is based on pytorch-lightning, but you don't need to be an expert in pytorch-lightning to understand it, since our focus is on the data part, not the model and training part.

python

# The purpose of this file is to test where to put the tokenizer

import os
import torch
from transformers import BertTokenizer, BertModel

from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

# init the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class TextDataset(Dataset):
    def __init__(self, text, target):
        self.len = len(text)
        self.text = text
        self.target = target

    def __getitem__(self, index):
        return (self.text[index], self.target[index])

    def __len__(self):
        return self.len


def collate_fn(batch):
    # unpack batch: a list of (text, target) tuples, one per sample
    text = [item[0] for item in batch]
    target = [item[1] for item in batch]

    # get the input tokens
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # get the target
    target = torch.tensor(target)

    return input_tokens, target


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased")
        self.linear = torch.nn.Linear(768, 1)
        self.loss = torch.nn.MSELoss()

    def forward(self, x):
        # x is the dict of tokenized inputs produced by collate_fn
        return self.linear(self.model(**x).last_hidden_state)

    def training_step(self, batch, batch_idx):
        # unpack batch
        tokens, target = batch

        # forward pass
        y = self.linear(self.model(**tokens).last_hidden_state).squeeze().sum(-1)
        loss = self.loss(y, target)

        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def init_inputs(n):
    # repeat three short example texts (and their targets) n times
    text = ["Yes!", "A sentence with more words.", "Come on!"] * n
    target = [1.0, 2.0, 3.0] * n

    return text, target


def run():
    # hyperparameters
    batch_size = 32

    # initialize the text
    train_text, train_target = init_inputs(n=int(1e4))

    # init the dataloader with the custom collate_fn
    train_data = DataLoader(
        TextDataset(train_text, train_target),
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
    )

    model = BoringModel()

    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_find_unused_parameters_true",
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
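If you want to verify the efficiency claim yourself, a rough timing sketch along these lines compares per-sample, fixed-length tokenization against per-batch dynamic tokenization (the texts and batch size are illustrative, and the numbers will vary with your hardware and tokenizer):

python

# rough timing sketch: per-sample max_length padding vs. per-batch dynamic padding
import time
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
texts = ["Yes!", "A sentence with more words.", "Come on!"] * 1000
batch_size = 32

# Method 2 style: one tokenizer call per text, padded to max_length
start = time.perf_counter()
for t in texts:
    tokenizer(t, padding="max_length", max_length=128, return_tensors="pt")
print("per-sample:", time.perf_counter() - start)

# Method 3 style: one tokenizer call per batch, padded to the longest text in the batch
start = time.perf_counter()
for i in range(0, len(texts), batch_size):
    tokenizer(texts[i : i + batch_size], padding=True, truncation=True, return_tensors="pt")
print("per-batch:", time.perf_counter() - start)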