King Yuen Chiu, Anthony

Abstract

The Attention Mechanism is a flexible and powerful framework. However, that same flexibility can make it confusing. This article explains the concepts of Query, Key, Value, and Output with an analogy of an Information Retrieval System. This article also shows how to achieve different Attention Mechanisms by varying:

  1. Query-Key compatibility function
  2. Query, Key, and Value choices

For example, the innovation of Self-Attention is to use the same vectors as the Query, Key, and Value, while the innovation of Scaled Dot-Product Attention is to use a different Query-Key compatibility function.

Information Retrieval System

An Information Retrieval System stores Values indexed by Keys. A Query is used to retrieve related Key-Value pairs. The system returns an Output that is the associated Value or a transformed version of the associated Value.

Exact Matching

Fig.1 Finding the Exact Income by a Person's Name

In Fig.1, we have a basic system that allows users to find the exact income by a person's name. The system stores the Values (Income) indexed by Keys (Name). When a user enters a Query (Name), the system returns an Output (Income), which is the Value (Income) related to the Query (Name).
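
As a minimal sketch (the names and incomes below are made up for illustration), this exact-matching system behaves like a Python dictionary:

# A toy exact-matching retrieval system: Values (Income) indexed by Keys (Name).
income_by_name = {"Alice": 1000, "Bob": 2000, "Carol": 3000}

# The Query must match a Key exactly; the Output is the associated Value.
query = "Alice"
output = income_by_name[query]  # 1000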

Query-Key Compatibility

Fig.2 Estimating a Person's Income by His/Her Name

Sometimes, we might not want an exact match between the Query and a Key. For example, in Fig.2, we want to estimate a person's income by their name. The Query might not be identical to any Key (Name). In this case, the system computes Query-Key compatibility scores between the Query and all the Keys (Name) in the system.

We can use different functions to measure compatibility. The compatibility function takes the Query and Key as inputs and gives a score.

Note: Query-Key compatibility has also been called similarity, relevancy, and alignment in different places.
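
As a toy illustration (the scoring rule below is made up and not one used in practice), a compatibility function simply maps a (Query, Key) pair to a score:

def compatibility(query: str, key: str) -> float:
  """A toy Query-Key compatibility function: the fraction of characters
  in the Query that also appear in the Key."""
  if not query:
    return 0.0
  return sum(ch in key for ch in query) / len(query)

# A higher score means the Key matches the Query more closely.
compatibility("Alicia", "Alice")  # roughly 0.83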

Retrieval Output - Weighted Sum of Values

The returned Output is the weighted sum of all Values, weighted by the Query-Key compatibility scores.

For example, in Fig.2,

  1. If the scores are $[0.9, 0.1, 0.0]$ respectively, the returned Output will be $1000 \times 0.9 + 2000 \times 0.1 + 3000 \times 0.0 = 1100$.

  2. If the scores are $[1.0, 0.0, 0.0]$ respectively, the returned Output will be $1000$, the same as retrieving the Value with the highest associated compatibility score.
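
A minimal code sketch of this weighted sum, using the made-up scores from the first case above:

# Weighted sum of Values (Incomes) using the Query-Key compatibility scores.
values = [1000, 2000, 3000]
scores = [0.9, 0.1, 0.0]

output = sum(s * v for s, v in zip(scores, values))  # 1000*0.9 + 2000*0.1 = 1100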

Information Retrieval in Sequential Data

We have discussed Query-Key compatibility in terms of an Information Retrieval System, but how can we apply it to Deep Learning? We can use different Query, Key, and Value choices for various NLP tasks. Let's consider a simple task below:

For the 5-word sentence "i like to eat apple", we want to know whether the word "apple" refers to a fruit or a company.

Let's use the following Query, Key, and Value choices:

  • Keys = Values are the five embedding vectors of the words = $\mathbf{k}_{1 \le s \le 5}$
  • Query is the embedding vector of the word "apple" = $\mathbf{q}$

Note: This configuration is the same as Self-Attention (discussed below).

To understand whether the word "apple" is a fruit or a company, we need to know the context of the sentence. We can compute compatibility scores to determine which words we should focus on to make that decision.

Ideally, we want the Key "eat" to have a higher score than the other Keys when the Query is "apple"; for example, $(\mathbf{q}, \mathbf{k}_4)$ should be more compatible than $(\mathbf{q}, \mathbf{k}_3)$.

Let's make up some scores such that only the Keys "eat" and "apple" have a high score with the Query "apple":

$$[0.0, 0.0, 0.0, 0.4, 0.6]$$

Then the Output for the Query "apple" will be the weighted sum of all Values:

$$\mathbf{o} = 0.4 \times \mathbf{k}_4 + 0.6 \times \mathbf{k}_5$$
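
Here is a small sketch of this computation, with random tensors standing in for the word embeddings:

import torch

# Random vectors stand in for the embeddings of "i like to eat apple".
emb_dim = 4
keys = torch.rand(5, emb_dim)   # Keys = Values: one embedding vector per word
scores = torch.tensor([0.0, 0.0, 0.0, 0.4, 0.6])

# Output = weighted sum of all Values (here the Values are the Key vectors).
output = (scores.unsqueeze(-1) * keys).sum(dim=0)  # 0.4 * k_4 + 0.6 * k_5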

The General Form of Attention Mechanism

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. -- Attention Is All You Need1

Query, Key, Value, and Output Vectors

In the setting of deep learning, we have a Query vector $\mathbf{q}$, a Key vector $\mathbf{k}$, and a Value vector $\mathbf{v}$. They live in their own spaces, that is, $\mathbf{q} \in \R^{d_q}$, $\mathbf{k} \in \R^{d_k}$, and $\mathbf{v} \in \R^{d_v}$. The returned Output vector $\mathbf{o}$ is in the Value space, that is, $\mathbf{o} \in \R^{d_v}$.

Recall the Key-Value setting for Sequential Data above: for a sentence of length $S$, there will be $S$ Key vectors and $S$ Value vectors. We can represent them as two matrices:

$$\mathbf{K} \in \R^{S \times d_k} = [\mathbf{k}_1^T, \mathbf{k}_2^T, ..., \mathbf{k}_S^T]$$
$$\mathbf{V} \in \R^{S \times d_v} = [\mathbf{v}_1^T, \mathbf{v}_2^T, ..., \mathbf{v}_S^T]$$

Query-Key Compatibility Function

A Query-Key compatibility function $f_{\mathbf{compatibility}}: (\R^{d_q}, \R^{d_k}) \to \R$ maps a Query vector and a Key vector to a scalar score. The higher the score, the more compatible the Query vector and Key vector are. We will have $S$ compatibility scores:

$$\forall 1 \le s \le S: \quad e_s \in \R = f_{\mathbf{compatibility}}(\mathbf{q}, \mathbf{k}_s)$$

Attention Weights and the Output Vectors

The attention weights are usually defined by applying softmax to the compatibility scores, such that the sum of the weights is 1. Here is how we compute these $S$ attention weights:

$$\forall 1 \le s \le S: \quad a_s \in \R = \frac{\exp(e_s)}{\sum_{s'=1}^S \exp(e_{s'})}$$

Then the Output vector $\mathbf{o} \in \R^{d_v}$ is the weighted sum of those $S$ Value vectors:

$$\mathbf{o} \in \R^{d_v} = \sum_{s=1}^S a_s \mathbf{v}_s$$
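
Putting these pieces together, here is a vector-form sketch of the computation for a single Query, assuming the dot product as the compatibility function (other choices are discussed below):

import torch

S, d_k, d_v = 5, 8, 8
q = torch.rand(d_k)      # one Query vector (assuming d_q = d_k here)
K = torch.rand(S, d_k)   # S Key vectors
V = torch.rand(S, d_v)   # S Value vectors

e = K @ q                     # S compatibility scores e_s
a = torch.softmax(e, dim=0)   # S attention weights that sum to 1
o = a @ V                     # Output vector: weighted sum of the Value vectors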

Pytorch Implementation of an Abstract Attention Module

Here we implement an abstract attention module in Pytorch. The compute_compatibility_score method is the compatibility function $f_{\mathbf{compatibility}}$. The whole implementation is based on the matrix form of the attention mechanism. We assume the second-last dimension of the tensors is the sequence length, and the last is the embedding dimension.

import torch
import typing

class AbstractAttention(torch.nn.Module):
  """
  We assume the second-last dimension of the tensors is the sequence length,
  and the last is the embedding dimension.
  """

  def __init__(self):
    super().__init__()

  def compute_compatibility_score(
      self,
      query: torch.Tensor,
      key: torch.Tensor,
  ):
    """
    Args:
      query (torch.Tensor): The Query tensor.
      key (torch.Tensor): The Key tensor.
    Return:
      torch.Tensor: The score tensor with the last two dimensions = (, S_decode, S_encode)
    """
    raise NotImplementedError

  def forward(
      self,
      value: torch.Tensor,
      key: torch.Tensor,
      query: torch.Tensor,
      attention_mask: typing.Optional[torch.Tensor] = None,
    ):
    """
    Args:
      value (torch.Tensor): The Value tensor. Last two dimensions = (, S_encode, D_v)
      key (torch.Tensor): The Key tensor. Last two dimensions = (, S_encode, D_k)
      query (torch.Tensor): The Query.  Last two dimensions = (, S_decode, D_q)
      attention_mask (typing.Optional[torch.Tensor]): A binary mask applied to
        the computed compatibility scores; positions where the mask is 0 are
        excluded from the attention. Last two dimensions = (S_decode, S_encode)
    Return:
      torch.Tensor: The Output tensor. Last two dimensions = (, S_decode, D_v)
    """

    # The attentions are among the second last dimension.
    encode_seq_len = value.size()[-2]
    decode_seq_len = query.size()[-2]
    assert value.size()[-2] == key.size()[-2]

    # Compute the compatibility scores. The scoring is based on the last
    # dimension.
    compatibility_scores = self.compute_compatibility_score(query, key)
    assert compatibility_scores.size()[-2:] == (decode_seq_len, encode_seq_len)

    # We set unwanted scores to -inf because
    # 1. This gives a 0 attention weight after applying Softmax.
    # 2. There can be negative values in the scores matrix, so 0 is not
    #    a good choice for the fill value.
    if attention_mask is not None:
      compatibility_scores = compatibility_scores.masked_fill(
          attention_mask == 0, float('-inf')
      )

    # Compute softmax over the last dimension (across the encode sequence)
    attention_weights = torch.nn.functional.softmax(
        compatibility_scores, dim=-1
    )
    assert attention_weights.size()[-2:] == (decode_seq_len, encode_seq_len)

    # Compute matrix multiplication over the last two dimensions
    output = torch.matmul(attention_weights, value)
    assert output.size()[-2] == decode_seq_len
    assert output.size()[-1] == value.size()[-1]

    return output, attention_weights

Different Query-Key Compatibility Functions

Dot-Product Attention

The dot product is a common choice for measuring the angular similarity between two vectors. Dot-Product Attention refers to using the dot product of the Query vector and the Key vector as the compatibility function. 2

$$f_{\mathbf{compatibility}}(\mathbf{q}, \mathbf{k}) \in \R = \mathbf{q}^T \mathbf{k}$$

Dot-Product Attention is the simplest form of compatibility function, as there are no hidden learnable weights. However, this method assumes the Query and Key vectors have the same dimension.

Scaled Dot-Product Attention

Scaled Dot-Product Attention is a variant of Dot-Product Attention, where we scale the dot product by a factor of $\frac{1}{\sqrt{d_k}}$. 3

$$f_{\mathbf{compatibility}}(\mathbf{q}, \mathbf{k}) \in \R = \frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d_k}}$$

Note: The original Transformer paper 3 shows the matrix form of the scaled dot-product attention. The formula above is the vector form that takes two vectors instead of matrices.
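
For reference, the matrix form in the paper computes the Outputs for all Queries at once:

$$\mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}$$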

Query, Key, and Value Linear Transformation

Fig.3 Linear Transformation Before Dot-Product in Transformer

If the original Query, Key, and Value vectors are not the same size, we can apply linear transformations to the vectors before applying a compatibility function. This projects the input vectors into the same dimension. Fig.3 shows the linear transformations before the scaled dot-product in the standard Transformer model.
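
Here is a rough sketch of this idea (the dimensions below are illustrative, not the exact Transformer configuration): project the Query, Key, and Value into a shared dimension with torch.nn.Linear, then apply the scaled dot-product.

import torch

d_q, d_k_in, d_v_in = 12, 8, 8   # original (possibly different) dimensions
d_model = 16                     # shared dimension after projection

w_q = torch.nn.Linear(d_q, d_model)
w_k = torch.nn.Linear(d_k_in, d_model)
w_v = torch.nn.Linear(d_v_in, d_model)

query = torch.rand(10, d_q)      # S_decode = 10
key = torch.rand(20, d_k_in)     # S_encode = 20
value = torch.rand(20, d_v_in)

# Project everything to d_model, then apply the scaled dot-product.
q, k, v = w_q(query), w_k(key), w_v(value)
scores = torch.matmul(q, k.transpose(-1, -2)) / (d_model ** 0.5)  # (10, 20)
output = torch.matmul(torch.softmax(scores, dim=-1), v)           # (10, d_model)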

Pytorch Implementation of Different Compatibility Functions

Here we implement different compatibility functions in Pytorch. The whole implementation is based on the matrix form of the attention mechanism. We assume the second-last dimension of the tensors is the sequence length, and the last is the embedding dimension.

# Matrix Form Pytorch Implementation of Dot-Product Attention
class DotProductAttention(AbstractAttention):
  def __init__(self):
    super().__init__()

  def compute_compatibility_score(
      self, query: torch.Tensor, key: torch.Tensor,
  ):
    scores = torch.matmul(query, torch.transpose(key, -1, -2))
    return scores

attention = DotProductAttention()

value = torch.rand(100, 20, 10)
key = torch.rand(100, 20, 5)
query = torch.rand(100, 10, 5)
out, attention_weights = attention(value, key, query)

# (torch.Size([100, 10, 10]), torch.Size([100, 10, 20]))
out.size(), attention_weights.size()

# Matrix Form Pytorch Implementation of Scaled Dot-Product Attention
class ScaledDotProductAttention(AbstractAttention):
  def __init__(self):
    super().__init__()

  def compute_compatibility_score(
      self, query: torch.Tensor, key: torch.Tensor,
    ):
    # Scale by the square root of the Key (and Query) dimension d_k.
    scale = key.size()[-1] ** 0.5
    scores = torch.matmul(query, torch.transpose(key, -1, -2)) / scale
    return scores

attention = ScaledDotProductAttention()

value = torch.rand(100, 20, 10)
key = torch.rand(100, 20, 5)
query = torch.rand(100, 10, 5)
out, attention_weights = attention(value, key, query)

# (torch.Size([100, 10, 10]), torch.Size([100, 10, 20]))
out.size(), attention_weights.size()

Different Query, Key, Value Choices

Other than varying the Query-Key compatibility function, we can also change the Query, Key, and Value choices. We can combine different compatibility functions and Query, Key, and Value choices to construct other attention mechanisms. 4

For example, in the Transformer, we have Self-Attention with the Scaled Dot-Product compatibility function.

Fig.4 Different Attention Layers in Transformer

We can examine different Query, Key, and Value choices in the standard Transformer model in Fig.4.

Self-Attention [Fig.4 - 1, 2]

When we have the same Query vectors, Key vectors, and Value vectors, we call it self-attention. The example stated in the "Information Retrieval in Sequential Data" section above is a self-attention example.

Self-Attention has the following choices:

  • Keys = Values are embedding vectors of each word in a sequence = $\mathbf{k}_{1 \le s \le S}$
  • Query is an embedding vector of a word in the same sequence = $\mathbf{q}_{1 \le s \le S}$

attention = ScaledDotProductAttention()

seq_len = 3
emb_dim = 5
value = torch.rand(2, seq_len, emb_dim)
out, attention_weights = attention(value, value, value)

# (torch.Size([2, 3, 5]), torch.Size([2, 3, 3]))
out.size(), attention_weights.size()

Masked self-attention is a variant of self-attention where we mask out the attention weights of future words:

attention = ScaledDotProductAttention()

seq_len = 3
emb_dim = 5

# Construct an attention mask with last two dimensions = (S_decode, S_encode)
prev_token_mask = torch.zeros((seq_len, seq_len))
for decode_idx in range(seq_len):
  # Only enable the encode indices that are less than or equal to the
  # decode idx
  prev_token_mask[decode_idx][0: decode_idx + 1] = 1

value = torch.rand(1, seq_len, emb_dim)
out, attention_weights = attention(value, value, value, attention_mask=prev_token_mask)

# tensor([[[1.0000, 0.0000, 0.0000],
#          [0.4597, 0.5403, 0.0000],
#          [0.3136, 0.3519, 0.3345]]])
attention_weights

Encoder-Decoder [Fig.4 - 3]

When we want to use a source sequence to predict a target sequence, it is common to see an Encoder-Decoder pattern. In Fig.4, we can see that the connection between the Encoder and the Decoder is an Attention layer. That attention layer has the following choices (a usage sketch follows the list):

  • Keys = Values are embedding vectors of each word in the encoded source sequence.
  • Query is an embedding vector of a word in the target sequence.
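
Below is a minimal usage sketch of this Encoder-Decoder attention, reusing the ScaledDotProductAttention module above with random placeholder tensors:

attention = ScaledDotProductAttention()

src_len, tgt_len, emb_dim = 7, 4, 5
encoder_outputs = torch.rand(1, src_len, emb_dim)     # Keys = Values from the source sequence
decoder_embeddings = torch.rand(1, tgt_len, emb_dim)  # Queries from the target sequence

out, attention_weights = attention(encoder_outputs, encoder_outputs, decoder_embeddings)

# (torch.Size([1, 4, 5]), torch.Size([1, 4, 7]))
out.size(), attention_weights.size()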

Footnotes

  1. [Paper] Attention Is All You Need

  2. [Website] Papers with Code - Dot-Product Attention

  3. [Website] Papers with Code - Scaled Dot-Product Attention

  4. [Website] Attention? Attention! | Lil'Log