The Correct Way to Average Embeddings Using a Mask
A common task in deep learning is averaging the embeddings in a sequence: for example, the embedding of a sentence is computed as the average of the token embeddings it contains. The tricky part is that deep neural networks are usually trained in batches, and sequence lengths within the same batch can differ. For example, consider the following batch of two sentences:
# import the Hugging Face tokenizer class
from transformers import AutoTokenizer

# create a tokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# two sentences of different lengths
text = ['hello world', 'What a nice day!']
# get the tokens
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
print(input_ids)
# tensor([[ 0, 42891, 232, 2, 1, 1, 1],
# [ 0, 2264, 10, 2579, 183, 328, 2]])
print(attention_mask)
# tensor([[1, 1, 1, 1, 0, 0, 0],
# [1, 1, 1, 1, 1, 1, 1]])
As you can see, the first sentence (‘hello world’) is tokenized into four tokens, where 0 and 2 indicate the sentence start and end tokens. But because it’s in the same batch as a longer sentence, we append three padding tokens (1) to match lengths. The second sentence has no padding since it’s the longest in the batch.
The attention_mask tensor shows which positions are padding (0) and which are not (1).
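As a quick check (a small sketch that assumes the attention_mask tensor from above), summing the mask along the sequence dimension counts the real tokens in each sentence; these are exactly the counts we will divide by later:
# count the non-padding tokens in each sentence
valid_tokens = attention_mask.sum(dim=1)
print(valid_tokens)
# tensor([4, 7])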
The goal is to use attention_mask to compute the average embedding per sentence. Specifically, we average only the first four token embeddings of the first sentence; for the second sentence, we use all seven.
I recommend using einsum (Einstein summation) from PyTorch:
# import PyTorch and the Hugging Face model class
import torch
from transformers import AutoModel

# create a simple model
model = AutoModel.from_pretrained("roberta-base")
# calculate embeddings
with torch.no_grad():
    embs = model(**inputs).last_hidden_state  # (2, 7, 768)
#--- How to average the tensor of shape (2, 7, 768) to (2, 768)? ---#
# convert attention_mask to float
attention_mask_float = attention_mask.to(embs.dtype) # (2, 7)
# aggregate using einsum
emb_mean = torch.einsum("BSD,BS->BD", embs, attention_mask_float) # (2, 768)
# divide by the sequence length
# attention_mask_float.sum(dim=1, keepdim=True): (2, 1)
# use torch.clamp to avoid division by zero
emb_mean = emb_mean / torch.clamp(
    attention_mask_float.sum(dim=1, keepdim=True), min=1e-9
)
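If you’d rather avoid einsum, the same masked average can be written with plain broadcasting. This is a sketch that assumes the embs and attention_mask_float tensors from the snippet above; emb_mean_alt is just a name I’m introducing here:
# broadcast the mask over the embedding dimension and sum along the sequence
emb_sum = (embs * attention_mask_float.unsqueeze(-1)).sum(dim=1)  # (2, 768)
# divide by the number of valid tokens, guarding against division by zero
emb_mean_alt = emb_sum / torch.clamp(
    attention_mask_float.sum(dim=1, keepdim=True), min=1e-9
)
Both versions compute the same result; einsum just keeps the index bookkeeping in a single string instead of an unsqueeze plus a sum.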
The key line is torch.einsum("BSD,BS->BD", embs, attention_mask_float), where einsum performs the row-wise aggregation. B, S, and D represent batch size, sequence length, and embedding dimension, respectively; here (2, 7, 768).
The string "BSD,BS->BD" specifies the operation: BSD and BS are the dimensions of embs and attention_mask_float, and ->BD specifies the dimensions of the output tensor.
This computes a weighted sum along the sequence dimension S of embs, with weights from attention_mask_float.
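To make the weighted sum concrete, here is a tiny standalone example with made-up numbers (B=1, S=2, D=3) that you can verify by hand: the mask keeps the first position and drops the second.
import torch

x = torch.tensor([[[1., 2., 3.],
                   [4., 5., 6.]]])  # (1, 2, 3)
m = torch.tensor([[1., 0.]])        # (1, 2)
print(torch.einsum("BSD,BS->BD", x, m))
# tensor([[1., 2., 3.]])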
Since we’re calculating the mean, not the sum, we also divide by the number of valid tokens per sentence, using torch.clamp to guard against division by zero.
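As a final sanity check (a small sketch, assuming embs, attention_mask, and emb_mean from above are still in scope), the vectorized result matches a plain Python loop that averages only the valid tokens of each sentence:
# average each sentence's valid tokens one by one and compare
manual = torch.stack([
    embs[i, attention_mask[i].bool()].mean(dim=0) for i in range(embs.size(0))
])
print(torch.allclose(emb_mean, manual, atol=1e-6))
# expected: True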