Attention, Please: A Simplified Explanation of AI’s Attention Mechanism
Imagine you’re reading a novel and you come across a dialogue between two characters discussing a third character named Sarah. To fully grasp the conversation’s nuances, you don’t just focus on the words spoken about Sarah. Instead, your mind recalls Sarah’s past actions, personality traits, and relationship dynamics with the speakers. You’re effectively paying ‘attention’ to relevant information to enrich your understanding of the dialogue.
In AI language models, the attention mechanism follows a similar principle. It allows the model to focus on different parts of the input data (like a sentence) when performing tasks such as translation, summarization, or question answering. This is a departure from earlier sequential models, which struggled to capture rich relationships between distant words.
Now, let’s delve into the mathematical underpinnings of the attention mechanism with a tangible example to illustrate how it operates within the context of AI language models.
Mathematical Foundation of the Attention Mechanism
At its core, the attention mechanism is about assigning different levels of significance to various parts of the input data. Mathematically, it is expressed as:

Attention(Q, K, V) = SoftMax(QK^T / √dk) V
Here’s what each component represents:
- Q (Query): This is a set of vectors that represent the items you want to compare to every other item. In a sentence, each word can be a query when the model is trying to figure out its context.
- K (Key): The key vectors correspond to the items that will be compared against the queries. In the context of language, this could be every other word in the sentence.
- V (Value): Value vectors hold the information you actually want to retrieve. In NLP, the values often contain the embeddings of the words, which are rich representations of their meanings.
- dk: This is the dimensionality of the key vectors. Its square root, √dk, is used to scale the dot product so that the scores stay in a manageable range.
- SoftMax Function: This function turns the scaled scores into probabilities that add up to 1.
Now, let’s break down the attention calculation equation:
- Dot Product Between Q and K^T: You calculate the dot product to determine how much each query aligns with each key. It’s like measuring the similarity, or relevance, of every other word in the context to the current focus word.
- Scaling by √dk: This step is necessary to avoid extremely large dot-product values, which can lead to gradients that are too small for effective learning during backpropagation.
- SoftMax: Applying the SoftMax function turns the scores into probabilities. This step is akin to deciding how much to focus on each part of the input based on its relevance score.
- Multiplying by V: The output is a weighted sum of the value vectors based on these probabilities. You’re assembling a composite representation of the input that prioritizes the most relevant parts (see the code sketch after this list).
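Here is a minimal NumPy sketch of these four steps. It is a bare-bones illustration, not a production implementation: real attention layers add learned projections, masking, and batching.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Compute SoftMax(Q K^T / sqrt(d_k)) V for arrays of shape (n, d_k)."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)        # 1. dot products, 2. scaling
    # 3. SoftMax: subtract each row's max for numerical stability, then normalize.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                     # 4. weighted sum of the values
```

Each row of `weights` is a probability distribution over the input positions, and each row of the result is the context-enriched vector for one query.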
Example with a Sentence
Imagine the sentence:
“The cat chased the mouse.”
Let’s say we want to calculate the attention for the word:
“chased.”
Here, “chased” is our query (Q). We compare it to all the keys (K), which are the other words in the sentence. Each word also has a value (V), which in this case could be its meaning in relation to “chased.”
- Dot Product: We calculate how relevant each word in the sentence is to “chased.”
- Scaling: We adjust these relevance measures.
- SoftMax: We convert the relevance measures into probabilities. Perhaps “cat” and “mouse” are highly relevant to “chased,” so they get higher probabilities.
- Multiplication: We use these probabilities to combine the values of all the words, giving us a new, context-enriched representation of “chased” that takes into account its interaction with “cat” and “mouse.”
Let’s discuss this in detail. First, each word in the sentence is converted into a vector using word embeddings. These vectors are high-dimensional and contain semantic information about the words. For simplicity, let’s say our vectors are in 2D space (in reality, they are much higher-dimensional). For instance:
- “The” might be represented as [1, 0]
- “cat” might be [2, 3]
- “chased” might be [5, 1]
- “mouse” might be [4, 4]
In the attention mechanism, these word vectors are transformed into query (Q), key (K), and value (V) matrices through multiplication with weight matrices that the model learns during training. For simplicity, let’s assume our Q, K, and V are the same as the original word vectors.
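Here is what that transformation might look like in code, a minimal sketch assuming identity weight matrices so that Q, K, and V equal the word vectors and the numbers match the hand calculation below (`W_q`, `W_k`, and `W_v` stand in for the learned parameters):

```python
import numpy as np

# Toy 2-D embeddings from above, one row per word.
X = np.array([[1.0, 0.0],   # "The"
              [2.0, 3.0],   # "cat"
              [5.0, 1.0],   # "chased"
              [4.0, 4.0]])  # "mouse"

# In a trained model, W_q, W_k, and W_v are learned; identity matrices
# keep Q = K = V = X, matching the simplification above.
W_q = W_k = W_v = np.eye(2)
Q, K, V = X @ W_q, X @ W_k, X @ W_v
```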
Then we calculate the dot product of the query with all the keys to determine relevance. If we are focusing on “chased”, we compute:
- Dot product of “chased” (Q) with “The” (K)
- Dot product of “chased” (Q) with “cat” (K)
- Dot product of “chased” (Q) with “mouse” (K)

(In full self-attention, “chased” would also be compared with its own key and with the second “the”; the example skips those for simplicity.)
Let’s assume we get the following relevance scores (dot products; a quick code check follows the list):
- With “The”: 5×1 + 1×0 = 5
- With “cat”: 5×2 + 1×3 = 13
- With “mouse”: 5×4 + 1×4 = 24
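A self-contained check of those numbers with the toy vectors:

```python
import numpy as np

chased = np.array([5.0, 1.0])   # query vector for "chased"
keys = {"The": [1.0, 0.0], "cat": [2.0, 3.0], "mouse": [4.0, 4.0]}

for word, k in keys.items():
    print(word, chased @ np.array(k))   # The 5.0, cat 13.0, mouse 24.0
```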
To prevent the SoftMax from producing gradients that are too small, we scale the dot products by dividing by the square root of the dimension of the key vectors. If dk = 2 (the dimensionality of our vectors), we divide each score by √2, resulting in:
- With “The”: 5/√2 ≈ 3.54
- With “cat”: 13/√2 ≈ 9.19
- With “mouse”: 24/√2 ≈ 16.97
These scaled scores are then put through a SoftMax function to convert them into probabilities:

SoftMax(xi) = e^(xi) / Σj e^(xj)
This results in a probability distribution over the words, reflecting how much each word should be attended to in the context of “chased”. Each word’s value vector is then weighted by this probability, and these weighted value vectors are summed to produce the final output vector for the word “chased”.
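Carrying the toy numbers through these last two steps makes the behavior visible. The sketch below (same assumed vectors as above) shows how sharply peaked the distribution becomes even at this tiny scale:

```python
import numpy as np

scores = np.array([5.0, 13.0, 24.0]) / np.sqrt(2)   # "The", "cat", "mouse"
probs = np.exp(scores) / np.exp(scores).sum()
# probs ≈ [1.5e-06, 4.2e-04, 9.996e-01]: "mouse" dominates with these numbers.

values = np.array([[1.0, 0.0],    # "The"
                   [2.0, 3.0],    # "cat"
                   [4.0, 4.0]])   # "mouse"
output = probs @ values           # context-enriched vector for "chased"
# output ≈ [3.999, 4.000], almost exactly the "mouse" vector.
```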
This process is part of what’s called a single attention “head.” In practice, modern transformer models use multiple heads to capture different types of relationships between words, and the final representation for a word like “chased” is a concatenation of the outputs from all these attention heads.
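To make the multi-head idea concrete, here is an illustrative sketch. The dimensions and random weights are made up for the example, and the learned output projection that usually follows the concatenation is omitted:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

def multi_head_attention(X, num_heads=2):
    """Split the model dimension across heads, attend per head, concatenate."""
    n, d_model = X.shape
    d_head = d_model // num_heads
    rng = np.random.default_rng(0)
    heads = []
    for _ in range(num_heads):
        # Each head gets its own projections (random here, learned in practice).
        W_q, W_k, W_v = (rng.standard_normal((d_model, d_head)) for _ in range(3))
        heads.append(attention(X @ W_q, X @ W_k, X @ W_v))
    # The word's final representation concatenates every head's output.
    return np.concatenate(heads, axis=-1)

X = np.random.default_rng(1).standard_normal((4, 8))  # 4 words, d_model = 8
print(multi_head_attention(X).shape)                  # (4, 8)
```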