Multi-headed Attention in TFT
Multi-headed attention is a core mechanism in the Temporal Fusion Transformer (TFT) and other transformer-based models that allows the model to focus on different parts of the input sequence simultaneously. This is particularly useful for capturing various patterns and dependencies within time-series data. Here’s a detailed explanation of multi-headed attention in the context of TFT, along with an example.
Multi-Headed Attention
Multi-headed attention lets the model look at the input sequence from multiple perspectives by running several attention operations in parallel, each with its own learned projections, and then combining the results. Each “head” produces a different representation of the input data, and the outputs of all heads are concatenated and linearly transformed to create the final output.
Components of Multi-Headed Attention
Scaled Dot-Product Attention: Each head performs scaled dot-product attention, which involves:
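- Query (Q): one vector per position, describing what that position is looking for.
- Key (K): one vector per position, matched against the queries.
- Value (V): one vector per position, holding the content that gets aggregated.
The attention mechanism computes a weighted sum of the values, where each weight reflects how well a query matches the corresponding key. Mathematically: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$ Here, $d_k$ is the dimension of the keys; dividing by $\sqrt{d_k}$ keeps the dot products from growing with the key dimension and pushing the softmax into regions with very small gradients.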
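A minimal numpy sketch of the formula above; `scaled_dot_product_attention` is a hypothetical helper name and the random inputs are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # compatibility of each query with each key
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))   # 4 queries of dimension d_k = 8
K = rng.normal(size=(6, 8))   # 6 keys
V = rng.normal(size=(6, 5))   # 6 values of dimension 5
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape, w.shape)  # (4, 5) (4, 6)
```

If every key is identical, all scores in a row are equal, the weights become uniform, and the output reduces to the plain average of the values.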
Multiple Heads: Each head applies its own learned projections to Q, K, and V, so different heads can capture different aspects of the input data. With $h$ heads, head $i$ computes: $$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$ where $W_i^Q$, $W_i^K$, $W_i^V$ are learned linear projection matrices for the queries, keys, and values.
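In practice the $h$ per-head projections are usually computed as one matrix multiplication followed by a reshape. A minimal numpy sketch checking that the two forms agree, assuming self-attention (Q, K, V all derived from the same input `X`) and the common convention $d_k = d_{model}/h$; the sizes and random weights are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, d_model, h = 5, 16, 4
d_k = d_model // h  # assumed convention: split d_model evenly across heads

X = rng.normal(size=(seq_len, d_model))   # self-attention: Q, K, V all come from X

# Per-head projection matrices W_i^Q, each of shape (d_model, d_k)
W_Q_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
per_head = [X @ W for W in W_Q_heads]      # h arrays of shape (seq_len, d_k)

# Batched form: place the W_i^Q side by side in one (d_model, h * d_k) matrix,
# project once, then reshape into (h, seq_len, d_k)
W_Q = np.concatenate(W_Q_heads, axis=1)
batched = (X @ W_Q).reshape(seq_len, h, d_k).transpose(1, 0, 2)

print(all(np.allclose(batched[i], per_head[i]) for i in range(h)))  # True
```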
Concatenation and Final Linear Transformation: The outputs of all heads are concatenated and passed through a final linear layer: $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O $$ where $W^O$ is the output projection matrix.
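Putting the three components together, a minimal numpy sketch of standard multi-head self-attention. The function name `multi_head_attention`, the sizes, and the random weights are illustrative assumptions, not a particular library's API.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """Self-attention over X (seq_len, d_model).
    W_Q, W_K, W_V: (d_model, d_model), holding the h per-head projections side by side.
    W_O: (d_model, d_model) output projection."""
    seq_len, d_model = X.shape
    d_k = d_model // h

    def split(M):  # (seq_len, d_model) -> (h, seq_len, d_k)
        return M.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, seq_len, seq_len)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                 # (h, seq_len, d_k)
    # Concat(head_1, ..., head_h): put the heads side by side along the feature axis
    concat = heads.transpose(1, 0, 2).reshape(seq_len, h * d_k)
    return concat @ W_O, weights

rng = np.random.default_rng(2)
seq_len, d_model, h = 7, 16, 4
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) * d_model**-0.5 for _ in range(4))
out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape, weights.shape)  # (7, 16) (4, 7, 7)
```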
Example in Temporal Fusion Transformers
Consider a time-series forecasting task where we want to predict future sales based on historical sales data, promotional activities, holidays, and other features. The TFT uses multi-headed attention to effectively learn and weigh the importance of these various inputs over different time steps.
Example Walkthrough
Input Data:
- Historical sales data.
- Promotional activity indicators.
- Holiday indicators.
- Additional covariates (e.g., economic indicators, weather data).
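Encoding Input Sequences: Each time step's inputs are first turned into a single vector (in TFT, via variable selection networks, an LSTM encoder-decoder, and static enrichment), and the attention layer then projects these vectors into queries, keys, and values. For instance, assume we have: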
- Sales data from the past 30 days.
- Promotional activity for the same period.
- Holiday indicators for the same period.
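A minimal numpy sketch of this encoding step, assuming synthetic random placeholder data for the 30 days and replacing TFT's variable selection and LSTM layers with a single hypothetical linear embedding `W_embed`; `d_model = 16` is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(3)
T, d_model = 30, 16

# Synthetic placeholder data for 30 days (not real sales figures)
sales = rng.normal(loc=150, scale=20, size=T)
promo = rng.integers(0, 2, size=T).astype(float)
holiday = (rng.random(T) < 0.1).astype(float)

# Stack the features into one (T, n_features) matrix, standardizing sales
features = np.column_stack([(sales - sales.mean()) / sales.std(), promo, holiday])

# Hypothetical linear embedding into d_model dimensions (a stand-in for TFT's
# variable selection + LSTM layers), followed by the Q/K/V projections
W_embed = rng.normal(size=(features.shape[1], d_model)) * 0.5
X = features @ W_embed                         # (30, d_model): one vector per day
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_model)) * d_model**-0.5 for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(features.shape, Q.shape, K.shape, V.shape)  # (30, 3) (30, 16) (30, 16) (30, 16)
```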
Attention Heads: The multi-headed attention mechanism might use several heads to focus on different aspects of the input data:
- Head 1: Might focus on short-term dependencies in the sales data (e.g., last week’s sales).
- Head 2: Might focus on the impact of promotional activities.
- Head 3: Might consider the influence of holidays.
- Head 4: Might capture long-term trends in the sales data.
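A minimal numpy sketch of how per-head attention patterns can be inspected, assuming random, untrained projections in place of learned ones (so no meaningful specialization should be expected) and the causal decoder mask TFT applies so that each day attends only to itself and earlier days. The encoded inputs `X` are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(4)
T, d_model, h = 30, 16, 4
d_k = d_model // h
X = rng.normal(size=(T, d_model))   # placeholder for the encoded daily vectors

# Causal mask: True above the diagonal, i.e. where day t would see a later day
mask = np.triu(np.ones((T, T), dtype=bool), k=1)

weights = []
for i in range(h):
    # Each head has its own (here random, untrained) projections W_i^Q, W_i^K
    Wq, Wk = rng.normal(size=(2, d_model, d_k))
    scores = (X @ Wq) @ (X @ Wk).T / np.sqrt(d_k)
    scores = np.where(mask, -np.inf, scores)   # block attention to future days
    weights.append(softmax(scores, axis=-1))
weights = np.stack(weights)                     # (h, T, T)

# Which day does each head weight most heavily when processing the last day?
for i in range(h):
    print(f"head {i + 1}: strongest attention on day {weights[i, -1].argmax() + 1}")
```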
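Combining Attention Outputs: Each head computes its attention scores and weighted sums independently. In the standard formulation above, the head outputs are then concatenated and passed through a final linear layer. TFT uses a modified, interpretable variant: the value projection $W_V$ is shared across all heads, and the head outputs are averaged rather than concatenated before a final linear layer $W_H$, so the averaged attention weights form a single pattern that can be inspected.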
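A minimal numpy sketch of TFT-style interpretable multi-head attention (shared value projection, averaged heads). The function name, shapes, head count, and random weights are illustrative assumptions, and decoder masking is omitted for brevity. The final check shows why sharing the values matters: averaging the heads' outputs equals applying the averaged attention matrix to the shared values.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def interpretable_multi_head(X, W_Q, W_K, W_V, W_H):
    """Sketch of TFT-style interpretable attention.
    W_Q, W_K: (h, d_model, d_attn) per-head projections.
    W_V: (d_model, d_attn) value projection shared by all heads.
    W_H: (d_attn, d_model) final linear layer."""
    h, _, d_attn = W_Q.shape
    V = X @ W_V                                  # one set of values for every head
    heads, attn = [], []
    for i in range(h):
        A = softmax((X @ W_Q[i]) @ (X @ W_K[i]).T / np.sqrt(d_attn), axis=-1)
        attn.append(A)
        heads.append(A @ V)
    H_tilde = np.mean(heads, axis=0)             # average the heads instead of concatenating
    return H_tilde @ W_H, np.mean(attn, axis=0)

rng = np.random.default_rng(5)
T, d_model, h, d_attn = 7, 16, 4, 4
X = rng.normal(size=(T, d_model))
W_Q = rng.normal(size=(h, d_model, d_attn))
W_K = rng.normal(size=(h, d_model, d_attn))
W_V = rng.normal(size=(d_model, d_attn))
W_H = rng.normal(size=(d_attn, d_model))
out, avg_attn = interpretable_multi_head(X, W_Q, W_K, W_V, W_H)

# Because V is shared, the output equals the averaged attention pattern applied to V,
# so avg_attn is a single (T, T) matrix that can be inspected directly.
print(np.allclose(out, avg_attn @ (X @ W_V) @ W_H))  # True
```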
Final Forecasting: The combined attention output, which now integrates various patterns and dependencies learned by different heads, is used to make the final sales forecast for the desired future period.
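A hedged sketch of this last step. In TFT the attention output first passes through further gating, residual, and position-wise feed-forward layers, and the model outputs quantile forecasts for each future time step; here all of that is collapsed into one linear map. The horizon `tau = 7`, the quantile levels, and the random weights are hypothetical, so the forecast values themselves are meaningless.

```python
import numpy as np

rng = np.random.default_rng(6)
T_past, tau, d_model = 30, 7, 16          # hypothetical: 30 past days, 7-day horizon
quantiles = [0.1, 0.5, 0.9]               # hypothetical quantile levels

# Placeholder for the combined attention output over past + future positions
attn_out = rng.normal(size=(T_past + tau, d_model))

# Keep only the future positions and map each one to a value per quantile
W_out = rng.normal(size=(d_model, len(quantiles))) * d_model**-0.5
b_out = np.zeros(len(quantiles))
forecast = attn_out[-tau:] @ W_out + b_out   # (tau, n_quantiles)
print(forecast.shape)  # (7, 3)
```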
Visualization Example
Imagine a scenario where we have the following inputs for a retail store:
- Sales Data: [100, 150, 120, 170, 160, 130, 140]
- Promotion Indicator: [0, 1, 0, 1, 0, 0, 1]
- Holiday Indicator: [0, 0, 1, 0, 0, 1, 0]
Head 1:
- Focus: Recent sales trend.
- Query: Last three sales data points.
- Keys/Values: All sales data points.
Head 2:
- Focus: Effect of promotions.
- Query: Promotion indicator.
- Keys/Values: Sales data points.
Head 3:
- Focus: Impact of holidays.
- Query: Holiday indicator.
- Keys/Values: Sales data points.
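This split is a conceptual illustration. In an actual attention layer, every head receives the same input sequence (here, one vector per day combining sales, promotion, and holiday features); what differs is each head's learned projections $W_i^Q$, $W_i^K$, $W_i^V$, and any specialization like the one above emerges from training rather than being assigned. Each head computes its attention scores and weighted averages, and these outputs are then combined to form the final attention output.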
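Applying the mechanics to the 7-day example, a minimal numpy sketch under stated assumptions: every head receives the same per-day feature matrix, sales are standardized, the projections are random and untrained, a causal mask is applied, and `h = 3`, `d_k = 4` are arbitrary. The printed weights show how much day 7 attends to each day; with untrained weights they illustrate the mechanics only.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# The 7-day example from the text
sales = np.array([100, 150, 120, 170, 160, 130, 140], dtype=float)
promo = np.array([0, 1, 0, 1, 0, 0, 1], dtype=float)
holiday = np.array([0, 0, 1, 0, 0, 1, 0], dtype=float)

# One row per day: [standardized sales, promotion, holiday]
X = np.column_stack([(sales - sales.mean()) / sales.std(), promo, holiday])
T, n_feat = X.shape
h, d_k = 3, 4

rng = np.random.default_rng(7)
mask = np.triu(np.ones((T, T), dtype=bool), k=1)    # no peeking at future days
head_outputs = []
for i in range(h):
    # Untrained random projections; every head sees the same X
    Wq, Wk, Wv = rng.normal(size=(3, n_feat, d_k))
    scores = np.where(mask, -np.inf, (X @ Wq) @ (X @ Wk).T / np.sqrt(d_k))
    A = softmax(scores, axis=-1)
    head_outputs.append(A @ (X @ Wv))
    print(f"head {i + 1} weights for day 7:", np.round(A[-1], 2))

combined = np.concatenate(head_outputs, axis=1)     # (7, h * d_k), before W^O
print(combined.shape)  # (7, 12)
```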
Summary
Multi-headed attention in Temporal Fusion Transformers allows the model to consider several aspects of the input data at once, such as short-term patterns, long-term trends, and the impact of external factors like promotions and holidays. Because each head learns its own projections, the TFT can extract diverse and complementary features from the time series, which can improve forecast accuracy and robustness; in TFT, the attention weights can also be inspected to see which past time steps the model relied on.