The Correct Way to Average Embeddings Using a Mask
A common task in deep learning is to average embeddings over a sequence. For example, the embedding of a sequence (e.g., a sentence) is often computed as the average of its token embeddings. What makes this tricky is that deep neural networks are usually trained in batches, and sequences in the same batch usually have different lengths, so the shorter ones must be padded. For example, consider the following batch of two sentences:
```python
from transformers import AutoTokenizer

# create a tokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# two sentences of different lengths
text = ['hello world', 'What a nice day!']
# tokenize, padding the shorter sentence to the length of the longer one
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
print(input_ids)
# tensor([[    0, 42891,   232,     2,     1,     1,     1],
#         [    0,  2264,    10,  2579,   183,   328,     2]])
print(attention_mask)
# tensor([[1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1]])
```
As you can see, the first sentence ('hello world') is tokenized into four tokens, where `0` and `2` are the sentence start (`<s>`) and end (`</s>`) tokens. Because it is in the same batch as a longer sentence, the tokenizer appends three padding tokens (ID `1`) so that both rows have the same length. The second sentence has no padding tokens since it's the longest in the batch. The `attention_mask` tensor shows which positions are padding (`0`) and which are real tokens (`1`).
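A quick way to sanity-check the mask is to count the valid tokens in each sentence. The sketch below hard-codes the `input_ids` and `attention_mask` values printed above so it runs without downloading the tokenizer. It also shows that, for this batch, the mask can be rebuilt from `input_ids` by comparing against the padding ID `1`; hard-coding that ID is an assumption for illustration, and in real code you would read it from `tokenizer.pad_token_id`.

```python
import torch

# Values copied from the tokenizer output above (RoBERTa: <s>=0, <pad>=1, </s>=2)
input_ids = torch.tensor([[0, 42891, 232, 2, 1, 1, 1],
                          [0, 2264, 10, 2579, 183, 328, 2]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])

# Number of real (non-padding) tokens in each sentence
lengths = attention_mask.sum(dim=1)
print(lengths)  # tensor([4, 7])

# For this batch, the mask is simply "token is not padding"
PAD_ID = 1  # assumption: hard-coded here; prefer tokenizer.pad_token_id
rebuilt_mask = (input_ids != PAD_ID).long()
print(torch.equal(rebuilt_mask, attention_mask))  # True
```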
The goal is to use `attention_mask` to compute the average embedding of each sentence. Specifically, for sentence one we average only the first four token embeddings, and for sentence two we use all seven.
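Before vectorizing anything, it helps to write the intended result as a plain loop that can serve as ground truth. In this sketch, random numbers stand in for the model's `last_hidden_state` (an assumption to keep it self-contained), and the mask is the one printed above. It also shows why a plain `embs.mean(dim=1)` gives the wrong answer for a padded row.

```python
import torch

torch.manual_seed(0)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])
# Random stand-in for last_hidden_state, shape (B, S, D) = (2, 7, 768)
embs = torch.randn(2, 7, 768)

# Reference: average only the real tokens of each sentence, one row at a time.
# Slicing with [:length] assumes right padding, as in the tokenizer output above.
ref = torch.stack([
    embs[i, :int(attention_mask[i].sum())].mean(dim=0)
    for i in range(embs.shape[0])
])  # (2, 768)

# Naive mean over all 7 positions also averages in the padding embeddings
naive = embs.mean(dim=1)  # (2, 768)

print(torch.allclose(ref[1], naive[1]))  # True: sentence two has no padding
print(torch.allclose(ref[0], naive[0]))  # False: sentence one is padded
```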
I recommend using `einsum` (Einstein summation), which PyTorch provides as `torch.einsum`:
```python
import torch
from transformers import AutoModel

# load the pretrained RoBERTa encoder
model = AutoModel.from_pretrained("roberta-base")
# calculate token embeddings (no gradients needed for inference)
with torch.no_grad():
    embs = model(**inputs).last_hidden_state  # (2, 7, 768)

#--- How to average the tensor of shape (2, 7, 768) to (2, 768)? ---#
# convert attention_mask to the same float dtype as embs
attention_mask_float = attention_mask.to(embs.dtype)  # (2, 7)
# masked sum over the sequence dimension using einsum
emb_mean = torch.einsum("BSD,BS->BD", embs, attention_mask_float)  # (2, 768)
# divide by the number of valid tokens in each sentence
# `attention_mask_float.sum(dim=1, keepdim=True)`: (2, 1)
# clamp the denominator so an all-padding row can't cause division by zero
emb_mean = emb_mean / torch.clamp(
    attention_mask_float.sum(dim=1, keepdim=True), min=1e-9
)
```
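To check the einsum version without downloading RoBERTa, here is a sketch that replaces `model(**inputs).last_hidden_state` with random numbers of the same shape (an assumption for self-containment) and compares the result with a per-sentence mean over the real tokens.

```python
import torch

torch.manual_seed(0)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])
embs = torch.randn(2, 7, 768)  # stand-in for last_hidden_state

# Same steps as in the code above
attention_mask_float = attention_mask.to(embs.dtype)
emb_mean = torch.einsum("BSD,BS->BD", embs, attention_mask_float)
emb_mean = emb_mean / torch.clamp(
    attention_mask_float.sum(dim=1, keepdim=True), min=1e-9
)

# Ground truth: mean over the 4 real tokens of row 0 and all 7 tokens of row 1
# (assumes right padding, as in the tokenizer output above)
expected = torch.stack([embs[0, :4].mean(dim=0), embs[1, :7].mean(dim=0)])

print(emb_mean.shape)                                 # torch.Size([2, 768])
print(torch.allclose(emb_mean, expected, atol=1e-6))  # True
```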
The key line is `torch.einsum("BSD,BS->BD", embs, attention_mask_float)`, where `einsum` performs a masked sum over the sequence dimension. `B`, `S`, and `D` represent the batch size, sequence length, and embedding dimension, which in our case are 2, 7, and 768, respectively.
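Written out element by element, with m standing for `attention_mask_float` (notation introduced here for illustration), the einsum computes:

```latex
\text{emb\_sum}_{b,d} \;=\; \sum_{s=1}^{S} \text{embs}_{b,s,d}\, m_{b,s},
\qquad m_{b,s} \in \{0, 1\}
```

For sentence one, only the first four positions have m = 1, so the three padding embeddings contribute nothing to the sum.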
The string `"BSD,BS->BD"` is an equation that specifies the operation to be performed. In this equation:
- `BSD` and `BS` label the dimensions of the input tensors `embs` and `attention_mask_float`, respectively.
- `->BD` specifies the dimensions of the output tensor. Because `S` appears in the inputs but not in the output, `einsum` sums over it.
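One way to read `"BSD,BS->BD"` is as a set of nested loops: every label kept in the output (`B`, `D`) becomes an outer loop, and the label that disappears (`S`) is summed over. The sketch below spells this out on a tiny random tensor (sizes chosen arbitrarily for illustration) and checks it against `torch.einsum`. It is meant for understanding only; the Python loops are far slower than `einsum`.

```python
import torch

torch.manual_seed(0)
B, S, D = 2, 3, 4  # tiny sizes for illustration
embs = torch.randn(B, S, D)
mask = torch.tensor([[1., 1., 0.],
                     [1., 1., 1.]])  # (B, S)

# "BSD,BS->BD" written as explicit loops
out = torch.zeros(B, D)
for b in range(B):          # B is kept in the output -> outer loop
    for d in range(D):      # D is kept in the output -> outer loop
        for s in range(S):  # S is missing from the output -> summed over
            out[b, d] += embs[b, s, d] * mask[b, s]

print(torch.allclose(out, torch.einsum("BSD,BS->BD", embs, mask)))  # True
```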
The operation performed here is essentially a weighted sum of `embs` along its second dimension (`S`), with the weights provided by `attention_mask_float`.
In other words, for each batch element (`B`) and each embedding dimension (`D`), it multiplies the embeddings (`embs`) by the attention mask (`attention_mask_float`) and then sums over the sequence dimension (`S`). Because the mask contains only 0s and 1s, padding tokens are multiplied by 0 and contribute nothing to the sum. The same equation works unchanged with non-binary weights, such as attention scores, in which case it computes a genuinely weighted sum.
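The same masked (or weighted) sum can also be written without `einsum`, which helps when reading code that uses other idioms. This sketch, again with random stand-in embeddings, checks that broadcasting with `unsqueeze` and a batched matrix multiply with `torch.bmm` give the same result. It then applies the identical equation to non-binary weights; the softmax over random scores used here is an arbitrary choice purely for illustration.

```python
import torch

torch.manual_seed(0)
embs = torch.randn(2, 7, 768)
mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1]], dtype=embs.dtype)

via_einsum = torch.einsum("BSD,BS->BD", embs, mask)
# Broadcasting: (2, 7, 768) * (2, 7, 1), then sum over the sequence axis
via_broadcast = (embs * mask.unsqueeze(-1)).sum(dim=1)
# Batched matmul: (2, 1, 7) @ (2, 7, 768) -> (2, 1, 768) -> (2, 768)
via_bmm = torch.bmm(mask.unsqueeze(1), embs).squeeze(1)

print(torch.allclose(via_einsum, via_broadcast, atol=1e-5))  # True
print(torch.allclose(via_einsum, via_bmm, atol=1e-5))        # True

# Non-binary weights (illustration only): softmax over random scores,
# with padding positions set to -inf so they receive zero weight
scores = torch.randn(2, 7).masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(scores, dim=1)                # each row sums to 1
weighted = torch.einsum("BSD,BS->BD", embs, weights)  # (2, 768)
```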
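Since we're calculating the mean, not the sum, we also need to divide the result by the number of valid tokens in each sentence. This step is done in:

```python
emb_mean = emb_mean / torch.clamp(
    attention_mask_float.sum(dim=1, keepdim=True), min=1e-9
)
```

We use `torch.clamp` with `min=1e-9` so that the denominator can never be zero. If a row's mask were all zeros, the division would otherwise be 0/0 and produce NaN; with the clamp, such a row comes out as a zero vector instead.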
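Putting the pieces together, here is a sketch of a small reusable helper. The name `masked_mean` is hypothetical (it is not a PyTorch or Transformers API); the function simply follows the steps above. The example adds a third, all-padding row to exercise the edge case that the clamp guards against, and compares it with an unclamped division.

```python
import torch

def masked_mean(embs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average (B, S, D) embeddings over S, ignoring positions where the mask is 0.

    `masked_mean` is a hypothetical helper name used for illustration.
    """
    mask = attention_mask.to(embs.dtype)                            # (B, S)
    summed = torch.einsum("BSD,BS->BD", embs, mask)                 # (B, D)
    counts = torch.clamp(mask.sum(dim=1, keepdim=True), min=1e-9)   # (B, 1)
    return summed / counts

torch.manual_seed(0)
embs = torch.randn(3, 7, 768)  # random stand-in for last_hidden_state
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1],
                               [0, 0, 0, 0, 0, 0, 0]])  # extra all-padding row

out = masked_mean(embs, attention_mask)  # (3, 768); row 2 is all zeros

# Without the clamp, the all-padding row is 0 / 0, i.e. NaN
mask_f = attention_mask.float()
unclamped = (torch.einsum("BSD,BS->BD", embs, mask_f)
             / mask_f.sum(dim=1, keepdim=True))
print(torch.isnan(unclamped[2]).all())  # tensor(True)
```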