Attention Mechanism - What are Query, Key, and Value?
Abstract
The Attention Mechanism is a flexible and powerful framework. However, its flexibility can make it confusing. This article explains the concepts of Query, Key, Value, and Output with an analogy to an Information Retrieval System. It also shows how to achieve different Attention Mechanisms by varying:
- the Query-Key compatibility function
- the Query, Key, and Value choices
For example, the innovation of Self-Attention is to use the same vectors as Query, Key, and Value, while the innovation of Scaled Dot-Product Attention is to use a different Query-Key compatibility function.
Table of Contents
- Information Retrieval System
- Exact Matching
- Query-Key Compatibility
- Retrieval Output - Weighted Sum of Values
- Information Retrieval in Sequential Data
- The General Form of Attention Mechanism
- Query, Key, Value, and Output Vectors
- Query-Key Compatibility Function
- Attention Weights and the Output Vectors
- PyTorch Implementation of an Abstract Attention Module
- Different Query-Key Compatibility Functions
- Dot-Product Attention
- Scaled Dot-Product Attention
- Query, Key, and Value Linear Transformation
- PyTorch Implementation of Different Compatibility Functions
- Different Query, Key, Value Choices
- Self-Attention [Fig.4 - 1, 2]
- Encoder-Decoder [Fig.4 - 3]
Information Retrieval System
An Information Retrieval System stores Values indexed by Keys. A Query is used to retrieve related Key-Value pairs. The system returns an Output that is the associated Value or a transformed version of the associated Value.
Exact Matching
In Fig.1, we have a basic system that allows users to look up a person's exact income by name. The system stores the Values (Income) indexed by Keys (Name). When a user enters a Query (Name), the system returns an Output (Income), which is the Value (Income) associated with that Query (Name).
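As a minimal sketch of exact matching (the names and incomes below are made up for illustration), the system is just a key-value lookup:
# A toy exact-match retrieval system: Keys are names, Values are incomes.
income_by_name = {
    "alice": 60000,
    "bob": 45000,
    "carol": 80000,
}
query = "bob"
output = income_by_name[query]  # the Output is the Value stored under the matching Key
print(output)  # 45000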
Query-Key Compatibility
Sometimes, we might not want an exact match between the Query and the Key. For example, in Fig.2, we want to estimate a person's income from their name, but the Query might not be the same as any Key (Name). In this case, the system computes Query-Key compatibility scores between the Query and all the Keys (Name) in the system.
We can use different functions to measure compatibility. A compatibility function takes the Query and a Key as inputs and produces a score, as in the sketch below.
Note: Query-Key compatibility has also been called similarity, relevancy, and alignment in different places.
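As a minimal sketch (the character-overlap measure below is a made-up stand-in for a real compatibility function), a compatibility function simply maps a Query and a Key to a score:
# A toy compatibility function between a Query name and a Key name.
# Character overlap is used here purely for illustration; in practice the
# function operates on learned vector representations.
def compatibility(query: str, key: str) -> float:
    query_chars, key_chars = set(query), set(key)
    return len(query_chars & key_chars) / len(query_chars | key_chars)

print(compatibility("jon", "john"))   # higher score for a similar name
print(compatibility("jon", "alice"))  # lower score for a dissimilar name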
Retrieval Output - Weighted Sum of Values
The returned Output is the weighted sum of all Values, weighted by the Query-Key compatibility scores:
$\text{Output} = \sum_{i} \text{score}(\text{Query}, \text{Key}_i) \cdot \text{Value}_i$
For example, in Fig.2:
- If the scores are spread over several Keys, the returned Output is a blend of the corresponding Values.
- If the score is 1 for a single Key and 0 for all the others, the returned Output equals that Key's Value, the same as retrieving the Value with the highest associated compatibility score.
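As a minimal sketch (with made-up incomes and scores), the weighted-sum retrieval looks like this:
import torch

# Made-up Values (incomes) and made-up compatibility scores for one Query.
values = torch.tensor([60000.0, 45000.0, 80000.0])
scores = torch.tensor([0.7, 0.2, 0.1])

# The Output is the weighted sum of all Values.
output = torch.sum(scores * values)
print(output)  # tensor(59000.)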
Information Retrieval in Sequential Data
We have discussed Query-Key compatibility in terms of an Information Retrieval System, but how can we apply it to Deep Learning? We can use different Query, Key, and Value choices for various NLP tasks. Let's consider a simple task below:
For the 5-word sentence "i like to eat apple", we want to know whether the word "apple" refers to a fruit or a company.
Let's use the following Query, Key, and Value choices:
- Keys = Values = the five embedding vectors, one for each word in the sentence
- Query = the embedding vector of the word "apple"
Note: This configuration is the same as Self-Attention (discussed below).
To understand if the word "apple" is a fruit or a company, we need to know the context of the sentence. We can compute compatibility scores to determine which words we should focus on to decide if the word "apple" is a fruit or a company.
Ideally, we want the Key "eat" to have a higher score than the other Keys when the Query is "apple"; for example, the score between "apple" and "eat" should be higher than the score between "apple" and "to".
Let's make up some scores such that only the Keys "eat" and "apple" have a high score with the Query "apple":
Then the Output for the Query "apple" will be the weighted sum of all Values, weighted by these scores, as in the sketch below:
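A minimal sketch of this configuration, using random embeddings and made-up scores (the numbers carry no meaning beyond illustration):
import torch

# Random embeddings for the 5-word sentence "i like to eat apple".
# In a real model these would come from a trained embedding layer.
words = ["i", "like", "to", "eat", "apple"]
embeddings = torch.rand(5, 8)             # Keys = Values: one 8-dim vector per word
query = embeddings[words.index("apple")]  # Query: the embedding of "apple" (scores below are made up instead of computed)

# Made-up compatibility scores: only "eat" and "apple" score high.
scores = torch.tensor([0.1, 0.1, 0.1, 4.0, 4.0])
weights = torch.softmax(scores, dim=-1)   # normalize the scores into attention weights

# The Output for the Query "apple" is the weighted sum of all Values.
output = weights @ embeddings
print(weights)        # most of the weight falls on "eat" and "apple"
print(output.shape)   # torch.Size([8])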
The General Form of Attention Mechanism
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. -- Attention Is All You Need1
Query, Key, Value, and Output Vectors
In the setting of deep learning, we have a Query vector $q \in \mathbb{R}^{d_q}$, a Key vector $k \in \mathbb{R}^{d_k}$, and a Value vector $v \in \mathbb{R}^{d_v}$; each lives in its own space. The returned Output vector lives in the Value space, that is, $o \in \mathbb{R}^{d_v}$.
Recall the Key-Value setting for Sequential Data above: for a sentence of length $n$, there will be $n$ Key vectors and $n$ Value vectors. We can represent them as two matrices $K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$.
Query-Key Compatibility Function
A Query-Key compatibility function $f$ maps a Query vector and a Key vector to a scalar score. The higher the score, the more compatible the Query vector and the Key vector are. For a sequence of length $n$, we will have $n$ compatibility scores:
$s_i = f(q, k_i), \quad i = 1, \dots, n$
Attention Weights and the Output Vectors
The attention weights are usually defined by applying softmax to the compatibility scores, such that the sum of the weights is 1. Here is how we compute these attention weights:
$a_i = \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)}$
Then the Output vector is the weighted sum of those Value vectors:
$o = \sum_{i=1}^{n} a_i v_i$
PyTorch Implementation of an Abstract Attention Module
Here we implement an abstract attention module in PyTorch. The compute_compatibility_score method implements the compatibility function $f$. The whole implementation is based on the matrix form of the attention mechanism. We assume the second-to-last dimension of each tensor is the sequence length, and the last dimension is the embedding dimension.
import torch
import typing
class AbstractAttention(torch.nn.Module):
"""
We assume the second-to-last dimension of the tensors is the sequence length,
and the last is the embedding dimension.
"""
def __init__(self):
super().__init__()
def compute_compatibility_score(
self,
query: torch.Tensor,
key: torch.Tensor,
):
"""
Args:
query (torch.Tensor): The Query tensor.
key (torch.Tensor): The Key tensor.
Return:
torch.Tensor: The score tensor with the last two dimensions = (, S_decode, S_encode)
"""
raise NotImplementedError
def forward(
self,
value: torch.Tensor,
key: torch.Tensor,
query: torch.Tensor,
attention_mask: typing.Optional[torch.Tensor] = None,
):
"""
Args:
value (torch.Tensor): The Value tensor. Last two dimensions = (, S_encode, D_v)
key (torch.Tensor): The Key tensor. Last two dimensions = (, S_encode, D_k)
query (torch.Tensor): The Query. Last two dimensions = (, S_decode, D_q)
attention_mask: typing.Optional[torch.Tensor]: It is a binary mask being
passed into the compute_compatibility_score function to post-process
the calculated scores. (S_decode, S_encode)
Return:
torch.Tensor: The Output tensor. Last two dimensions = (, S_decode, D_v)
"""
# The attentions are among the second last dimension.
encode_seq_len = value.size()[-2]
decode_seq_len = query.size()[-2]
assert value.size()[-2] == key.size()[-2]
# Compute the compatibility scores. The scoring is based on the last
# dimension.
compatibility_scores = self.compute_compatibility_score(query, key)
assert compatibility_scores.size()[-2:] == (decode_seq_len, encode_seq_len)
# We set the unwanted scores to -inf because:
# 1. This gives a 0 attention weight after applying softmax.
# 2. There can be negative values in the scores matrix, so 0 is not
#    a good choice for the fill value.
if attention_mask is not None:
    compatibility_scores = compatibility_scores.masked_fill(
        attention_mask == 0, float('-inf')
    )
# Compute softmax over the second last dimension
attention_weights = torch.nn.functional.softmax(
compatibility_scores, dim=-1
)
assert attention_weights.size()[-2:] == (decode_seq_len, encode_seq_len)
# Compute matrix multiplication over the last two dimensions
output = torch.matmul(attention_weights, value)
assert output.size()[-2] == decode_seq_len
assert output.size()[-1] == value.size()[-1]
return output, attention_weights
Different Query-Key Compatibility Functions
Dot-Product Attention
The dot product is a common choice for measuring the angular similarity between two vectors. Dot-Product Attention uses the dot product of the Query vector and the Key vector as the compatibility function: 2
$f(q, k) = q \cdot k = q^\top k$
Dot-Product Attention is the simplest form of compatibility measure, as there are no learnable weights hidden in the compatibility function. However, this method assumes the Query and Key vectors have the same dimension.
Scaled Dot-Product Attention
Scaled Dot-Product Attention is a variant of Dot-Product Attention, where we scale the dot product by a factor of $\frac{1}{\sqrt{d_k}}$: 3
$f(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$
Note: The original Transformer paper 3 presents the matrix form of Scaled Dot-Product Attention. The formula above is the vector form, which takes two vectors instead of matrices.
Query, Key, and Value Linear Transformation
If the original Query, Key, and Value vectors are not of the same size, we can apply linear transformations to the vectors before applying a compatibility function. This projects the input vectors into the same dimension. Fig.3 shows the linear transformations before the scaled dot-product in the standard Transformer model.
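A minimal sketch of this idea (the dimensions below are arbitrary choices): torch.nn.Linear layers project the original Query, Key, and Value vectors into a shared model dimension before the scaled dot product:
import torch

# Hypothetical original sizes: Query, Key, and Value live in different spaces.
d_q, d_k, d_v, d_model = 7, 5, 10, 8

# Learnable linear transformations that project everything to d_model.
w_q = torch.nn.Linear(d_q, d_model, bias=False)
w_k = torch.nn.Linear(d_k, d_model, bias=False)
w_v = torch.nn.Linear(d_v, d_model, bias=False)

query = torch.rand(1, 4, d_q)  # (batch, S_decode, D_q)
key = torch.rand(1, 6, d_k)    # (batch, S_encode, D_k)
value = torch.rand(1, 6, d_v)  # (batch, S_encode, D_v)

# After projection, the Query-Key dot product is well defined.
q, k, v = w_q(query), w_k(key), w_v(value)
scores = torch.matmul(q, k.transpose(-1, -2)) / (d_model ** 0.5)
weights = torch.softmax(scores, dim=-1)
output = torch.matmul(weights, v)
print(output.shape)  # torch.Size([1, 4, 8])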
PyTorch Implementation of Different Compatibility Functions
Here we implement different compatibility functions in PyTorch. The whole implementation is based on the matrix form of the attention mechanism. We assume the second-to-last dimension of each tensor is the sequence length, and the last dimension is the embedding dimension.
# Matrix Form PyTorch Implementation of Dot-Product Attention
class DotProductAttention(AbstractAttention):
def __init__(self):
super().__init__()
def compute_compatibility_score(
self, query: torch.Tensor, key: torch.Tensor,
):
scores = torch.matmul(query, torch.transpose(key, -1, -2))
return scores
attention = DotProductAttention()
value = torch.rand(100, 20, 10)
key = torch.rand(100, 20, 5)
query = torch.rand(100, 10, 5)
out, attention_weights = attention(value, key, query)
# (torch.Size([100, 10, 10]), torch.Size([100, 10, 20]))
out.size(), attention_weights.size()
# Matrix Form PyTorch Implementation of Scaled Dot-Product Attention
class ScaledDotProductAttention(AbstractAttention):
def __init__(self):
super().__init__()
def compute_compatibility_score(
self, query: torch.Tensor, key: torch.Tensor,
):
scaler = query.size()[-1] ** 0.5
scores = torch.matmul(query, torch.transpose(key, -1, -2)) / scaler
return scores
attention = ScaledDotProductAttention()
value = torch.rand(100, 20, 10)
key = torch.rand(100, 20, 5)
query = torch.rand(100, 10, 5)
out, attention_weights = attention(value, key, query)
# (torch.Size([100, 10, 10]), torch.Size([100, 10, 20]))
out.size(), attention_weights.size()
Different Query, Key, Value Choices
Other than varying the Query-Key compatibility function, we can also change the Query, Key, and Value choices. We can combine different compatibility functions with different Query, Key, and Value choices to construct other attention mechanisms. 4
For example, in the Transformer, we have Self-Attention with the Scaled Dot-Product compatibility function.
We can examine the different Query, Key, and Value choices in the standard Transformer model in Fig.4.
Self-Attention [Fig.4 - 1, 2]
When we use the same vectors as Query, Key, and Value, we call it self-attention. The example in the "Information Retrieval in Sequential Data" section above is a self-attention example.
Self-Attention has the following choices:
- Keys = Values = the embedding vectors of each word in a sequence
- Query = the embedding vector of a word in the same sequence
attention = ScaledDotProductAttention()
seq_len = 3
emb_dim = 5
value = torch.rand(2, seq_len, emb_dim)
out, attention_weights = attention(value, value, value)
# (torch.Size([2, 3, 5]), torch.Size([2, 3, 3]))
out.size(), attention_weights.size()
Masked self-attention is a variant of self-attention in which we mask out the attention weights of future words:
attention = ScaledDotProductAttention()
seq_len = 3
emb_dim = 5
# Construct an attention mask (, S_decode, S_encode)
prev_token_mask = torch.zeros((seq_len, seq_len))
for decode_idx in range(seq_len):
# Only enable the encode indices that are less than or equal to the
# decode idx
prev_token_mask[decode_idx][0: decode_idx + 1] = 1
value = torch.rand(1, seq_len, emb_dim)
out, attention_weights = attention(value, value, value, attention_mask=prev_token_mask)
# tensor([[[1.0000, 0.0000, 0.0000],
# [0.4597, 0.5403, 0.0000],
# [0.3136, 0.3519, 0.3345]]])
attention_weights
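As a side note, the same lower-triangular mask can be built without an explicit loop; a minimal equivalent using torch.tril:
# Equivalent lower-triangular mask built with torch.tril instead of a loop.
prev_token_mask = torch.tril(torch.ones(seq_len, seq_len))
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])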
Encoder-Decoder [Fig.4 - 3]
When we want to use a source sequence to predict a target sequence, it is common to see an Encoder-Decoder pattern. In Fig.3 we can see that the connection between the Encoder and the Decoder is an attention layer. That attention layer has the following choices:
- Keys = Values = the embedding vectors of each word in the encoded source sequence
- Query = the embedding vector of a word in the target sequence
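A minimal sketch of this choice, reusing the ScaledDotProductAttention module above with made-up source and target lengths:
attention = ScaledDotProductAttention()
src_len, tgt_len, emb_dim = 7, 4, 5

# Keys = Values: embeddings of the encoded source sequence.
encoder_states = torch.rand(1, src_len, emb_dim)
# Query: embeddings of the target sequence being decoded.
decoder_states = torch.rand(1, tgt_len, emb_dim)

out, attention_weights = attention(encoder_states, encoder_states, decoder_states)
# (torch.Size([1, 4, 5]), torch.Size([1, 4, 7]))
out.size(), attention_weights.size()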