PyTorch: iTransformer: Inverted Transformers Are Effective for Time Series Forecasting
from dataclasses import dataclass

import torch
from torch import nn


@dataclass(frozen=True)
class ForecastPrediction:
    forecast: torch.Tensor
    correlation_maps: tuple[torch.Tensor, ...]
class iTransformer(nn.Module):
    def __init__(
        self,
        *,
        variates: int,
        embed_dim: int,
        seq_length: int,
        blocks: int,
        forecast_length: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.variates = variates
        # Step 2: MLP that embeds each variate's full series into a variate token
        self.embedding = nn.Sequential(
            # nn.Linear(seq_length, seq_length),
            # nn.ReLU(),
            nn.Linear(seq_length, embed_dim),
        )
        # Stacked TrmBlocks
        self.trm_blocks = nn.ModuleList()
        for _ in range(blocks):
            trm_block = TrmBlock(embed_dim, dropout=dropout)
            self.trm_blocks.append(trm_block)
        # Step 11: MLP for projecting variate tokens back to predicted series
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, forecast_length),
        )
    def forward(self, x):
        # x: (batch, seq_length, variates)
        assert x.shape[-1] == self.variates
        # Step 1: Transpose so each variate's series lies along the last dimension
        x = x.transpose(-1, -2)  # (batch, variates, seq_length)
        # Step 3: Embed each series into a variate token
        h = self.embedding(x)  # (batch, variates, embed_dim)
        # Step 4: Run through the iTransformer blocks
        correlation_maps = []
        for trm_block in self.trm_blocks:
            h, correlation_map = trm_block(h)
            correlation_maps.append(correlation_map)
        # Step 11: Project variate tokens back to predicted series using the MLP
        y = self.projection(h)  # (batch, variates, forecast_length)
        # Step 12: Transpose the result back to (batch, forecast_length, variates)
        y = y.transpose(-1, -2)
        # Step 13: Return the prediction result
        return ForecastPrediction(y, tuple(correlation_maps))
class TrmBlock(nn.Module):
    def __init__(self, embed_dim: int, *, dropout: float = 0.1):
        super().__init__()
        # Step 5: Self-attention over variate tokens
        self.self_attn = MultivariateAttention(embed_dim, dropout=dropout)
        # Step 6: H^l = LayerNorm(H^{l-1} + Self-Attn(H^{l-1}))
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        # Step 7: Feed-forward network applied to each variate token
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim),
        )
        # Step 9: Second layer normalization
        self.layer_norm2 = nn.LayerNorm(embed_dim)
    def forward(self, x):
        # Multivariate self-attention over variate tokens
        attn_output, correlation_map = self.self_attn(x)
        # Step 6: LayerNorm on series representations to reduce discrepancies between variates
        h1 = self.layer_norm1(x + attn_output)
        # Step 7: Feed-forward network on each series representation, shared across tokens
        ff_output = self.feed_forward(h1)
        # Step 9: LayerNorm on series representations to reduce discrepancies between variates
        h2 = self.layer_norm2(h1 + ff_output)
        return h2, correlation_map
class MultivariateAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int = 1, *, dropout: float = 0.0) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )

    def forward(self, x):
        # x: (batch, variates, embed_dim). Attention is applied across the variate
        # tokens (not the temporal dimension), so the attention weights form a
        # (variates x variates) correlation map.
        attn_output, correlation_map = self.self_attn(x, x, x)
        return attn_output, correlation_map
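A minimal usage sketch (the tensor shapes and hyperparameter values below are illustrative assumptions, not part of the gist): build the model, run a forward pass on random multivariate data, and inspect the forecast and the per-block variate correlation maps.

# Hypothetical smoke test: random series of shape (batch, seq_length, variates),
# forecasting the next `forecast_length` steps for every variate.
model = iTransformer(
    variates=7,
    embed_dim=64,
    seq_length=96,
    blocks=2,
    forecast_length=24,
)
x = torch.randn(32, 96, 7)  # (batch=32, seq_length=96, variates=7)
pred = model(x)
print(pred.forecast.shape)             # torch.Size([32, 24, 7])
print(pred.correlation_maps[0].shape)  # torch.Size([32, 7, 7]), one map per block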