Tokenizer, Dataset, and "collate_fn"

TL;DR
For best efficiency, pass a custom collate_fn to your DataLoader and do all tokenization there. Don’t run the tokenizer inside Dataset.__getitem__.

When you train a language model with PyTorch, you first have to tokenize the texts before feeding them to the model. Generally, tokenization can happen in one of three places:

Method 1: Tokenize the texts once up front, then build the Dataset from the tokenized results. This is only feasible when the input corpus is small, e.g., a few hundred samples; a minimal sketch follows.

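Here is a minimal sketch of Method 1, assuming a small corpus and the same bert-base-uncased tokenizer used later in this post; the names (texts, targets, dataset) are illustrative:

python

import torch
from torch.utils.data import TensorDataset
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

texts = ["Yes!", "A sentence with more words.", "Come on!"]
targets = torch.tensor([1.0, 2.0, 3.0])

# one tokenizer call over the whole (small) corpus
encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# wrap the pre-tokenized tensors in a Dataset
dataset = TensorDataset(encoded["input_ids"], encoded["attention_mask"], targets)
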
Method 2: Tokenize each sample inside Dataset.__getitem__:

python

class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        # tokenize one sample at a time and pad it to a fixed max_length
        input_tokens = tokenizer(text, padding="max_length", max_length=128, truncation=True, return_tensors="pt")
        return input_tokens

Method 2 tokenizes every single input (self.text[index]) on its own and pads each one to a fixed max_length (128 here), no matter how short it is. As you can imagine, this is inefficient: the tokenizer runs once per sample instead of once per batch, and every sample carries the full max_length worth of padding.
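For example, a single short sentence tokenized this way still comes back as a 128-token tensor (assuming the tokenizer used in the snippets, e.g., BertTokenizer.from_pretrained("bert-base-uncased")):

python

# a single short sample is still padded all the way to max_length
one = tokenizer("Yes!", padding="max_length", max_length=128, truncation=True, return_tensors="pt")
print(one["input_ids"].shape)  # torch.Size([1, 128]) -- mostly padding tokens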

Method 3: Return raw text from __getitem__ and tokenize each batch in a custom collate_fn:

python

class TextDataset(Dataset):
    def __getitem__(self, index):
        text = self.text[index]
        return text

def collate_fn(batch):
    # get text
    text = [x for x in batch]

    # get the input tokens
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    return input_tokens

# init dataloader
dataloader = DataLoader(TextDataset(...), batch_size=16, collate_fn=collate_fn)

In Method 3, instead of returning tokenized results from __getitem__, we return raw text. All tokenization happens in collate_fn, the function the DataLoader uses to merge individual samples into a single batch. When we run text = [x for x in batch] in collate_fn, we get a list of texts of length batch_size. The tokenizer now runs once per batch instead of once per sample, and with padding=True each batch is padded only to its own longest sequence rather than a fixed max_length, so efficiency improves significantly.
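A quick way to see the dynamic padding in action (assuming the tokenizer and the collate_fn defined above):

python

# the batch is padded only to its own longest sequence, not to a fixed max_length
batch = ["Yes!", "A sentence with more words.", "Come on!"]
tokens = collate_fn(batch)
print(tokens["input_ids"].shape)  # e.g., torch.Size([3, 8]) -- far smaller than [3, 128]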

Here’s a minimal example to illustrate the idea. The code is based on PyTorch Lightning, but you don’t need to be an expert; our focus is data, not the model/training.

python

# The purpose of this file is to test where to put the tokenizer

import os
import torch
from transformers import BertTokenizer, BertModel

from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

# init the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class TextDataset(Dataset):
    def __init__(self, text, target):
        self.len = len(text)
        self.text = text
        self.target = target

    def __getitem__(self, index):
        return (self.text[index], self.target[index])

    def __len__(self):
        return self.len


def collate_fn(batch):
    # unpack batch
    text = [item[0] for item in batch]
    target = [item[1] for item in batch]

    # get the input tokens
    input_tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # get the target
    target = torch.tensor(target)

    return input_tokens, target


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased")
        self.linear = torch.nn.Linear(768, 1)
        self.loss = torch.nn.MSELoss()

    def forward(self, x):
        # encode with BERT, map each token embedding to a scalar, and sum over the sequence
        return self.linear(self.model(**x).last_hidden_state).squeeze(-1).sum(-1)

    def training_step(self, batch, batch_idx):
        # unpack batch
        tokens, target = batch

        # forward pass (reuses forward defined above)
        y = self(tokens)
        loss = self.loss(y, target)

        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def init_inputs(n):
    # build a toy corpus by repeating three short sentences n times
    text = ["Yes!", "A sentence with more words.", "Come on!"] * n
    target = [1.0, 2.0, 3.0] * n

    return text, target


def run():
    # hyperparameters
    batch_size = 32

    # initialize the text
    train_text, train_target = init_inputs(n=int(1e4))

    # init the dataloader
    train_data = DataLoader(
        TextDataset(train_text, train_target),
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
    )

    model = BoringModel()

    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_find_unused_parameters_true",
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()