
TFT: an Interpretable Transformer

A deep exploration of TFT, its implementation using Darts and how to interpret a Transformer
This article was authored by Rafael Guedes

Introduction

Every company in the world needs forecasting to plan its operations, regardless of the sector in which it operates. Companies have several forecasting use cases to solve, such as sales for yearly planning, customer service contacts for monthly planning of agents per language, SKU sales to plan production and/or procurement, and so on.

Although there are different use cases, all of them share one need from their stakeholders: interpretability! If you have deployed a forecast model for a stakeholder in the past, you have come across the question: ‘why is the model making such a prediction?’

In this article I explore TFT, an interpretable Transformer for time series forecasting. I also provide a step-by-step implementation of TFT to forecast weekly sales in a dataset from Walmart using Darts (a forecasting library for Python). Finally, I show how to interpret the model and its performance for a 16-week forecast horizon on the Walmart dataset.

TFT: Temporal Fusion Transformers

What is it?

Time series are usually influenced not only by their historical values but also by other inputs. They might contain a mix of complex inputs such as static covariates (i.e. time-invariant features like the brand of a product), dynamic covariates with known future inputs like the product discount, and other dynamic covariates with unknown future inputs such as the number of visitors in the coming weeks.

Several Deep Learning models have been proposed to tackle the presence of multiple inputs in time series forecasting, but they are typically ‘black-box’ models that do not allow us to understand how each component impacts the forecast produced.

Temporal Fusion Transformers (TFT) [1] is an attention-based architecture that combines multi-horizon forecasting with interpretable insights. It has recurrent layers to learn temporal relationships at different scales, self-attention layers for interpretability, variable selection networks to perform feature selection, gating layers to suppress unnecessary components, and a quantile loss function to produce forecast intervals. Figure 2 shows the TFT architecture, which is explained in more detail in the next section.

TFT architecture (source)
How does it work?

TFT has 5 major components:

1. Gating Mechanisms, powered by the Gated Residual Network (GRN), give the model the flexibility to apply non-linear processing only when needed. This is important because it is hard to know in advance the non-linear relationship between the dynamic covariates and the target. With small and noisy datasets, simpler models can be more beneficial, and the non-linear processing can be skipped.

The GRN works the following way:

  • It starts with a primary input a and a context vector c (the output of the static covariate encoder, explained later).
  • Those inputs go through a Dense Layer with an ELU [2] activation function, which acts as the identity function when $W_{2,\omega} a + W_{3,\omega} c + b_{2,\omega} \gg 0$ and as a constant when $W_{2,\omega} a + W_{3,\omega} c + b_{2,\omega} \ll 0$.
  • The output then goes through a second Dense Layer, which produces the input for the Gated Linear Unit (GLU) [3]. The GLU has a sigmoid gate that controls how much the GRN contributes to the original input a.
  • Because the sigmoid can output values close to 0, the GLU can suppress the non-linear contribution almost entirely, in which case Layer Normalization receives essentially just the original input a from the skip connection.
The inside of Gated Residual Network and how it works (image made by the author)
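
The GRN steps above can be sketched in a few lines of NumPy. This is only an illustration under simplifying assumptions, not the Darts implementation: the weights are random, dropout is omitted, and the parameter names (W1 to W5, b1 to b5) follow the notation of the paper [1] rather than anything defined in this article.

```python
import numpy as np

def elu(x):
    # Identity for positive inputs, saturates towards -1 for very negative ones
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def grn(a, c, p):
    """Gated Residual Network for one primary input a and one context vector c."""
    eta2 = elu(p["W2"] @ a + p["W3"] @ c + p["b2"])  # dense layer + ELU
    eta1 = p["W1"] @ eta2 + p["b1"]                   # second dense layer
    gate = sigmoid(p["W4"] @ eta1 + p["b4"])          # GLU gate, values in (0, 1)
    gated = gate * (p["W5"] @ eta1 + p["b5"])         # gated non-linear contribution
    return layer_norm(a + gated), gate                # skip connection + LayerNorm

rng = np.random.default_rng(0)
d = 4  # hidden size (illustrative value)
params = {f"W{i}": rng.normal(scale=0.1, size=(d, d)) for i in range(1, 6)}
params.update({f"b{i}": np.zeros(d) for i in (1, 2, 4, 5)})
a, c = rng.normal(size=d), rng.normal(size=d)

out, gate = grn(a, c, params)

# A very negative gate bias closes the gate, so only the skip connection remains
closed_params = dict(params, b4=np.full(d, -50.0))
out_closed, gate_closed = grn(a, c, closed_params)
```

With the gate closed, `out_closed` is numerically the same as `layer_norm(a)`, which is the "skip the non-linear contribution" behaviour described in the last bullet.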

2. Variable Selection Networks (VSN) help the model weigh the relevance and contribution of each static and dynamic covariate. Apart from showing which features are most important for the forecast, they also perform feature selection to remove unnecessary noisy inputs that can negatively impact performance. Each type of input (static, past and future) has its own VSN, represented by different colours in Figure 2.

The input features are transformed before getting into VSN. Categorical inputs are encoded into a d-dimensional embedding vector while numerical features are linearly transformed into a d-dimensional vector.
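
As a rough illustration of these two transformations, the sketch below uses random NumPy arrays in place of learned parameters. The dimension `d`, the store types and the temperature value are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # shared embedding dimension (hypothetical value)

# Categorical input: each category index looks up a learned d-dimensional row
store_types = ["A", "B", "C"]
embedding_table = rng.normal(size=(len(store_types), d))
type_embedding = embedding_table[store_types.index("B")]

# Numerical input: a scalar is mapped to d dimensions by a learned weight and bias
w, b = rng.normal(size=d), rng.normal(size=d)
temperature = 21.5
temperature_vector = w * temperature + b
```

Both kinds of input end up as vectors of the same length d, which is what allows the VSN to weigh them against each other.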

After that, VSN unfolds into two branches:

1st Branch:

  • The transformed features E(j) at time t are concatenated into a single vector of all past inputs at time t, such as [E(1)t, E(2)t, …, E(j-1)t, E(j)t], where j denotes a specific feature.
  • This vector, together with a Context Vector, goes through a GRN where a non-linear transformation is applied as explained previously.
  • The output of the GRN is passed to a Softmax layer that produces a vector of weights, one per feature. Feature selection happens in this step: each weight lies between 0 and 1 and the weights sum to 1, so irrelevant features can be pushed towards 0.
Process to get the feature importance of each feature (image made by the author)

2nd Branch:

  • Each transformed feature at time t feeds an independent GRN, which produces the output Ẽ through non-linear processing.
Transformation of E(j) at time t with GRN (image made by the author)

Combination Step:

  • With the feature importance vector and the non-linear transformations of E, an element-wise combination of both is performed to generate a processed feature vector weighted by relevance.
Combination of Feature weight and Transformed features (image made by the author)
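
The two branches and the combination step can be sketched as follows. To keep the example short, every GRN is replaced by a single tanh layer with random weights, which is a deliberate simplification; the names `W_sel` and `W_feat` are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
m, d = 3, 4                    # number of features, embedding dimension
E_t = rng.normal(size=(m, d))  # transformed features at time t
c_s = rng.normal(size=d)       # static context vector

# 1st branch: flatten all features, add the context, and turn it into weights
flat = np.concatenate([E_t.ravel(), c_s])
W_sel = rng.normal(size=(m, flat.size))
weights = softmax(np.tanh(W_sel @ flat))  # stand-in for GRN + Softmax

# 2nd branch: each feature gets its own non-linear transformation
W_feat = rng.normal(size=(m, d, d))
E_tilde = np.tanh(np.einsum("jkl,jl->jk", W_feat, E_t))  # stand-in for per-feature GRNs

# Combination: weight each transformed feature by its importance and sum
combined = (weights[:, None] * E_tilde).sum(axis=0)
```

The vector `weights` is what gets reported as feature importance later on, while `combined` is the d-dimensional representation passed on to the rest of the network.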

3. Static Covariate Encoders encode the static covariates into four different context vectors using four different GRNs:

  • Context vector $c_s$, used for temporal variable selection in the VSNs
  • Context vectors $c_c$ and $c_h$, used for local processing of temporal features in the LSTM Encoder-Decoder
  • Context vector $c_e$, used to enrich temporal features with static information in the Enrichment Layer
Static Covariate Encoder into 4 different context vectors for different uses (image made by the author)

4. Temporal Processing is important in time series because the surrounding observations are often the most useful for future predictions. This local context has already been exploited in attention-based architectures; however, those approaches are only suitable for observed inputs and cannot handle known future inputs at the same time.

To overcome this problem, the authors proposed a sequence-to-sequence model to handle past and known future inputs. They feed past inputs into an LSTM Encoder and known future inputs into an LSTM Decoder. Both Encoder and Decoder also use the context vectors $c_c$ and $c_h$ as inputs so that static metadata can influence local processing when creating the temporal features.

LSTM Encoder-Decoder Model (image made by the author)

The outputs of both Encoder and Decoder are combined with the context vector $c_e$ and sent, one time step at a time, to a GRN with shared weights in the Enrichment Layer, which enhances the temporal features with static metadata.

Final Temporal Features created in the Enrichment Layer (image made by the author)

Finally, the temporal features enriched with static metadata are fed to an interpretable multi-head attention layer that learns the relevance of each time step t with respect to the rest of the input sequence that precedes it.

Temporal Interpretability of Temporal Features in relation to t (image made by the author)
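
A simplified NumPy sketch of this attention layer is shown below. It follows the idea in the paper [1] that all heads share the same values and that their attention weights are averaged, so a single matrix can be inspected. The weights are random, the final output projection is left out, and all names are illustrative.

```python
import numpy as np

def causal_softmax(scores):
    """Row-wise softmax where position t may only attend to positions <= t."""
    T = scores.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)
    masked = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=1, keepdims=True)

def interpretable_attention(x, W_q, W_k, W_v):
    """x: (T, d); W_q, W_k: (H, d, d), one per head; W_v: (d, d), shared by all heads."""
    d = x.shape[1]
    V = x @ W_v
    per_head = [
        causal_softmax((x @ wq) @ (x @ wk).T / np.sqrt(d))
        for wq, wk in zip(W_q, W_k)
    ]
    attn = np.mean(per_head, axis=0)  # one (T, T) matrix that can be plotted
    return attn @ V, attn

rng = np.random.default_rng(0)
T, d, H = 5, 4, 2  # time steps, feature size, heads (illustrative values)
x = rng.normal(size=(T, d))
W_q = rng.normal(size=(H, d, d))
W_k = rng.normal(size=(H, d, d))
W_v = rng.normal(size=(d, d))

out, attn = interpretable_attention(x, W_q, W_k, W_v)
```

Row t of `attn` holds the weights that time step t places on itself and on the steps before it; entries for later steps are exactly zero because of the mask.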

5. Quantile Prediction is achieved by the prediction of various percentiles at each time step. The forecast is generated using a linear transformation of the output from the temporal fusion decoder.
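
The quantile (pinball) loss that drives these predictions can be written in a couple of lines. The sketch below is a plain NumPy version for a single quantile; the function name and the toy numbers are only for illustration.

```python
import numpy as np

def quantile_loss(y_true, y_pred, q):
    """Pinball loss for quantile q, averaged over all time steps."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(q * diff, (q - 1.0) * diff)))

y_true = [100.0, 120.0, 90.0]
under = [90.0, 110.0, 80.0]    # forecasts 10 units too low
over = [110.0, 130.0, 100.0]   # forecasts 10 units too high

loss_under = quantile_loss(y_true, under, q=0.75)  # 7.5
loss_over = quantile_loss(y_true, over, q=0.75)    # 2.5
```

For the 0.75 quantile, forecasting too low is penalised three times as much as forecasting too high, which pushes that output above the median. Training one output per quantile is what produces the forecast intervals.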

Information flow from the LSTM Decoder to the Dense layer that performs the quantile regression (image made by the author)

How to use and interpret TFT in practice

This section covers a step-by-step implementation of TFT using the same dataset as my previous post about TiDE: a weekly sales dataset from Walmart, available on Kaggle (License CC0: Public Domain).

I will use the implementation in Darts to train, predict and interpret TFT.

The dataset has 2 years and 8 months of weekly sales and 16 columns:

  • Store — store number and one of the static covariates
  • Dept — department number and another static covariate
  • Type — type of store and another static covariate
  • Size — size of the store and the last static covariate
  • Date — the temporal index of the time series which is weekly and it will be used to extract dynamic covariates like the week number and the month
  • Weekly_Sales — the target variable
  • IsHoliday — a dynamic covariate that identifies if there is a holiday in a specific week
  • Temperature — a dynamic covariate with the average temperature in a specific week
  • Fuel_Price — a dynamic covariate with the price of fuel in a specific week
  • MarkDown1, 2, 3, 4 and 5 — dynamic covariates with the average discounts in a specific week
  • CPI — a dynamic covariate with the consumer price index
  • Unemployment — a dynamic covariate with the unemployment rate

We start by importing the libraries and defining global variables like the date column, target column, static covariates, dynamic covariates to fill with 0, dynamic covariates to fill with linear interpolation, the frequency of our series, the forecast horizon and the scalers to use:

import pandas as pd
import numpy as np
from datetime import timedelta
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.pipeline import Pipeline
from darts.models import TFTModel
from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression
from darts.dataprocessing.transformers import StaticCovariatesTransformer, MissingValuesFiller

TIME_COL = "Date"
TARGET = "Weekly_Sales"
STATIC_COV = ["Store", "Dept", "Type", "Size"]
DYNAMIC_COV_FILL_0 = ["IsHoliday", 'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5']
DYNAMIC_COV_FILL_INTERPOLATE = ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment']
FREQ = "W-FRI"
FORECAST_HORIZON = 16 # weeks
SCALER = Scaler()
TRANSFORMER = StaticCovariatesTransformer()
PIPELINE = Pipeline([SCALER, TRANSFORMER])

The default scaler is scikit-learn's MinMaxScaler, but we can use any scaler we want as long as it has fit(), transform() and inverse_transform() methods. The same applies to the static covariates transformer, which by default uses scikit-learn's OrdinalEncoder for categorical static covariates and MinMaxScaler for numerical ones.
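
To make that three-method contract concrete, here is a tiny hypothetical scaler written with NumPy only. `Log1pScaler` is not part of Darts or scikit-learn, and whether a given Darts version accepts such a class unchanged is an assumption worth checking before relying on it.

```python
import numpy as np

class Log1pScaler:
    """Hypothetical scaler exposing fit(), transform() and inverse_transform()."""

    def fit(self, X):
        return self  # nothing to learn for a log transform

    def transform(self, X):
        return np.log1p(X)

    def inverse_transform(self, X):
        return np.expm1(X)

sales = np.array([[0.0], [1500.0], [32000.0]])
scaler = Log1pScaler().fit(sales)
restored = scaler.inverse_transform(scaler.transform(sales))
```

The round trip returns the original values, which is the property needed later to bring the forecasts back to the original sales scale.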

After that, we load our dataset and enrich it with the exogenous features mentioned in the dataset description:

# load data and exogenous features
df = pd.read_csv('data/train.csv')
store_info = pd.read_csv('data/stores.csv')
exo_feat = pd.read_csv('data/features.csv').drop(columns='IsHoliday')


# join all data frames
df = pd.merge(df, store_info, on=['Store'], how='left')
df = pd.merge(df, exo_feat, on=['Store', TIME_COL], how='left')
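
Left joins like the ones above duplicate rows without warning if the right-hand table has repeated keys. One way to guard against that, sketched here on made-up toy frames, is the `validate` argument of `pd.merge`, which raises an error when the expected key relationship does not hold.

```python
import pandas as pd

# Toy stand-ins for train.csv, stores.csv and features.csv
sales = pd.DataFrame({
    "Store": [1, 1, 2],
    "Dept": [1, 2, 1],
    "Date": ["2012-01-06", "2012-01-06", "2012-01-06"],
    "Weekly_Sales": [100.0, 50.0, 80.0],
})
stores = pd.DataFrame({"Store": [1, 2], "Type": ["A", "B"], "Size": [150000, 200000]})
features = pd.DataFrame({
    "Store": [1, 2],
    "Date": ["2012-01-06", "2012-01-06"],
    "Temperature": [40.1, 38.5],
})

# many_to_one: several sales rows may match one store or one (store, date) row, never the reverse
merged = pd.merge(sales, stores, on=["Store"], how="left", validate="many_to_one")
merged = pd.merge(merged, features, on=["Store", "Date"], how="left", validate="many_to_one")
```

After both joins the number of rows is unchanged, so every (Store, Dept, Date) combination still appears exactly once.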

Once the dataset is loaded we need to apply some preprocessing to clean up the data:

  • We convert the time column to datetime with pd.to_datetime
  • We convert negative values to 0 (those negative values might indicate returns but I did not spend a lot of time looking into it since it is out of the scope of this article)
  • We fill missing values in the MarkDown columns with 0, since we assume that a missing value is due to a lack of promotions
  • We convert the boolean column that identifies a holiday in a specific week to a binary column
  • We transform the Size static covariate from continuous to categorical: when the size is below the 25th percentile it is ‘small’, when it is above the 75th percentile it is ‘large’ and, when it is between the 25th and 75th percentiles, it is ‘medium’.
  • Finally, we forecast only the 7 stores with the highest volume of sales to reduce run time (I ran it once with all stores and it took me 11 hours to train the TFT model 😅).
df[TIME_COL] = pd.to_datetime(df[TIME_COL])
df[TARGET] = np.where(df[TARGET] < 0, 0, df[TARGET]) # remove negative values
df[['MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4','MarkDown5']] = df[['MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4','MarkDown5']].fillna(0) # fill missing values with 0
df["IsHoliday"] = df["IsHoliday"]*1 # convert boolean into binary
df["Size"] = np.where(df["Size"] < store_info["Size"].quantile(0.25), "small",
                      np.where(df["Size"] > store_info["Size"].quantile(0.75), "large",
                               "medium")) # make size a categorical variable

# reduce running time by forecasting only top 7 stores
top_7_stores = df.groupby(['Store']).agg({TARGET: 'sum'}).reset_index().sort_values(by=TARGET, ascending=False).head(7)
df = df[df['Store'].isin(top_7_stores['Store'])]

With the data cleaned, we can split it into train and test sets (I decided to use the last 16 weeks of data for the test set) and transform the pandas data frames into the Darts TimeSeries format.

  • Using the function TimeSeries.from_group_dataframe we can easily define the static covariates, our target, the column with the time reference and the frequency of the series.
  • I also used the fill_missing_dates argument together with fillna_value=0 to fill the target variable with 0 in case some of our series have gaps between weeks.
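# last 16 weeks for the test set
train = df[df[TIME_COL] <= (max(df[TIME_COL])-timedelta(weeks=FORECAST_HORIZON))]
test = df[df[TIME_COL] > (max(df[TIME_COL])-timedelta(weeks=FORECAST_HORIZON))]


# read train and test datasets and transform them into Darts TimeSeries
train_darts = TimeSeries.from_group_dataframe(
    df=train,
    group_cols=STATIC_COV,
    time_col=TIME_COL,
    value_cols=TARGET,
    freq=FREQ,
    fill_missing_dates=True,
    fillna_value=0)

# test_darts is used in the evaluation step; it is assumed here to be built like train_darts
test_darts = TimeSeries.from_group_dataframe(
    df=test,
    group_cols=STATIC_COV,
    time_col=TIME_COL,
    value_cols=TARGET,
    freq=FREQ,
    fill_missing_dates=True,
    fillna_value=0)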
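
To see what the cut-off logic does, the same split can be tried on a small made-up weekly series. The frame `toy` and the variable names below are only for illustration.

```python
import pandas as pd
from datetime import timedelta

horizon = 16  # weeks, same value as FORECAST_HORIZON
dates = pd.date_range("2012-01-06", periods=40, freq="W-FRI")
toy = pd.DataFrame({"Date": dates, "Weekly_Sales": range(40)})

# Everything after the cut-off date goes to the test set
cutoff = toy["Date"].max() - timedelta(weeks=horizon)
toy_train = toy[toy["Date"] <= cutoff]
toy_test = toy[toy["Date"] > cutoff]
```

With 40 weeks of data this leaves 24 weeks for training and exactly 16 weeks for testing, with no overlap between the two.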

We have the historical data in a TimeSeries format, so now it is time to create the dynamic covariates in the same format:

  • As before, we use the fill_missing_dates argument with fillna_value=0 for the dynamic covariates where I believe missing data can be replaced with 0.
  • For those that should be filled by interpolation, like Temperature, Fuel_Price, CPI and Unemployment, we don’t set the fillna_value argument and we use MissingValuesFiller() to linearly interpolate the missing values.
# create dynamic covariates for each series in the training set
dynamic_covariates = []
for serie in train_darts:
    # add the month and week as covariates
    covariate = datetime_attribute_timeseries(
        serie,
        attribute="month",
        one_hot=True,
        cyclic=False,
        add_length=FORECAST_HORIZON,
    )
    covariate = covariate.stack(
        datetime_attribute_timeseries(
            serie,
            attribute="week",
            one_hot=True,
            cyclic=False,
            add_length=FORECAST_HORIZON,
        )
    )

    store = serie.static_covariates['Store'].item()
    dept = serie.static_covariates['Dept'].item()
    serie_df = df[(df['Store'] == store) & (df['Dept'] == dept)]

    # create covariates to fill with 0
    covariate = covariate.stack(
        TimeSeries.from_dataframe(
            serie_df, time_col=TIME_COL, value_cols=DYNAMIC_COV_FILL_0,
            freq=FREQ, fill_missing_dates=True, fillna_value=0,
        )
    )
    # create covariates to fill with interpolation
    dyn_cov_interp = TimeSeries.from_dataframe(
        serie_df, time_col=TIME_COL, value_cols=DYNAMIC_COV_FILL_INTERPOLATE,
        freq=FREQ, fill_missing_dates=True,
    )
    covariate = covariate.stack(MissingValuesFiller().transform(dyn_cov_interp))
    dynamic_covariates.append(covariate)
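
The difference between the two filling strategies is easy to see on a toy pandas series. This only illustrates the idea with made-up temperatures; it does not call Darts.

```python
import numpy as np
import pandas as pd

weeks = pd.date_range("2012-01-06", periods=5, freq="W-FRI")
temperature = pd.Series([40.0, np.nan, np.nan, 46.0, 48.0], index=weeks)

zero_filled = temperature.fillna(0)                       # 40, 0, 0, 46, 48
interpolated = temperature.interpolate(method="linear")   # 40, 42, 44, 46, 48
```

Filling a temperature gap with 0 would create an artificial cold spell, whereas for markdowns a 0 is a plausible reading of "no promotion". That is why the two groups of covariates are treated differently.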

We are 2 steps away from fitting and predicting with TFT! We just need to scale the covariates and the historical data, and to encode the static covariates.

# scale covariates
dynamic_covariates_transformed = SCALER.fit_transform(dynamic_covariates)


# scale data and transform static covariates
data_transformed = PIPELINE.fit_transform(train_darts)

We are ready to train and predict with TFT:

TFT_params = {
    "input_chunk_length": 52, # number of weeks to lookback
    "output_chunk_length": FORECAST_HORIZON,
    "hidden_size": 2,
    "lstm_layers": 2,
    "num_attention_heads": 1,
    "dropout": 0.1,
    "batch_size": 16,
    "n_epochs": 3,
    "likelihood": QuantileRegression(quantiles=[0.25, 0.5, 0.75]),
    "random_state": 42,
    "use_static_covariates": True,
    "optimizer_kwargs": {"lr": 1e-3},
}
tft_model = TFTModel(**TFT_params)
tft_model.fit(data_transformed, future_covariates=dynamic_covariates_transformed, verbose=False)
pred = PIPELINE.inverse_transform(
    tft_model.predict(
        n=FORECAST_HORIZON,
        series=data_transformed,
        num_samples=50,
        future_covariates=dynamic_covariates_transformed,
    )
)

TFT: Interpretability

TFT is an interpretable Transformer, and Darts provides that functionality in just a few lines!

We start by defining our explainer that receives as input the model that we just trained and the historical data of the series we want to interpret together with its dynamic covariates. In our case, we will interpret the forecast for Store 2 — Dept 2.

from darts.explainability import TFTExplainer

explainer = TFTExplainer(
    tft_model,
    background_series=data_transformed[1],
    background_future_covariates=dynamic_covariates_transformed[1],
)
explainability_result = explainer.explain()

After that, we can interpret the model through its 2 variants:

Variable Selection Network

  • Encoder Importance shows that seasonality features (week and month) have the highest importance for our model, followed by IsHoliday and Markdown5.
plt.rcParams["figure.figsize"] = (10,5)
plt.barh(data=explainer._encoder_importance.melt().sort_values(by='value').tail(10), y='variable', width='value')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Encoder Importance')
plt.show()
 
Encoder Importance (image made by the author)
  • Decoder Importance shows that seasonality features (week and month) have the highest importance for our model.
plt.rcParams["figure.figsize"] = (10,5)
plt.barh(data=explainer._decoder_importance.melt().sort_values(by='value').tail(10), y='variable', width='value')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Decoder Importance')
plt.show()
 
Decoder Importance (image made by the author)
  • Static Covariate Importance shows that the feature we created based on store Size is the most important static covariate.
plt.rcParams["figure.figsize"] = (10,5)
plt.barh(data=explainer._static_covariates_importance.melt().sort_values(by='value').tail(10), y='variable', width='value')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Static Cov Importance')
plt.show()
Static Covariate Importance (image made by the author)

Temporal Multi-Head Attention Layer

This plot shows the importance of historical data for each time step in the forecast horizon and, as expected, the closer we are to the cut-off date, the higher the importance of past values.

explainer.plot_attention(explainability_result, plot_type="all", show_index_as='time')
 
Historical data Importance (image made by the author)

TFT: Performance

For Store 2 and Dept 2, our model had a performance of 3.58% MAPE and 3220.08 RMSE in the test set.

It managed to correctly predict most of the drops and spikes, and only 2 actual values are outside of the quantile ranges as you can see in Figure 16.

from darts.metrics import mape, rmse

def eval_model(val_series, pred_series):
    # plot actual series
    plt.figure(figsize=(9, 6))
    val_series[: pred_series.end_time()].plot(label="actual")
    # plot prediction with quantile ranges
    pred_series.plot(
        low_quantile=0.25, high_quantile=0.75, label=f"{int(0.25 * 100)}-{int(0.75 * 100)}th percentiles"
    )
    plt.title(f"MAPE: {round(mape(val_series, pred_series),2)}% | RMSE: {round(rmse(val_series, pred_series),2)}")
    plt.legend()

eval_model(test_darts[1], pred[1])
 
TFT performance for Store 2 Dept 2 (image made by the author)
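
For readers less familiar with the two metrics in the plot title, they can be written directly in NumPy as below. The function names and the numbers are made up for the example; this is not the Darts implementation.

```python
import numpy as np

def mape_pct(actual, predicted):
    """Mean absolute percentage error, in percent."""
    actual, predicted = np.asarray(actual, dtype=float), np.asarray(predicted, dtype=float)
    return float(100.0 * np.mean(np.abs((actual - predicted) / actual)))

def rmse_value(actual, predicted):
    """Root mean squared error, in the units of the target."""
    actual, predicted = np.asarray(actual, dtype=float), np.asarray(predicted, dtype=float)
    return float(np.sqrt(np.mean((actual - predicted) ** 2)))

actual = [100.0, 200.0, 400.0]
predicted = [110.0, 180.0, 400.0]

toy_mape = mape_pct(actual, predicted)    # about 6.67 (%)
toy_rmse = rmse_value(actual, predicted)  # about 12.91
```

MAPE is relative, so it can be compared across departments with very different sales volumes, while RMSE is expressed in sales units and penalises large misses more heavily.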

Conclusion

The complexity of state-of-the-art models has been increasing over time, which makes it difficult to explain how they work to audiences that do not have a background in AI.

Regardless of the type of model you deploy in your organisation, whether it is a simple classifier or a complex LLM, your stakeholders will always ask you at some point ‘why is the model making such a prediction?’. Therefore, when researching a model to solve a business problem, it is not just a matter of accuracy but also a matter of interpretability.

In this article, we explored a Transformer model for time series that is able to provide insights into what is impacting the predictions, and we saw how to interpret them.

References

[1] Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister. Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting. arXiv:1912.09363, 2020.

[2] Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter. Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs). arXiv:1511.07289, 2016.

[3] Yann N. Dauphin, Angela Fan, Michael Auli, David Grangier. Language Modeling with Gated Convolutional Networks. arXiv:1612.08083, 2017.

More articles: https://zaai.ai/lab/