
@munro
Created October 12, 2023 13:43
PyTorch: iTransformer: Inverted Transformers Are Effective for Time Series Forecasting
# @TODO this runs, but isn't fitting!
from dataclasses import dataclass

import torch
from torch import nn


@dataclass(frozen=True)
class ForecastPrediction:
    forecast: torch.Tensor
    correlation_maps: tuple[torch.Tensor, ...]


class iTransformer(nn.Module):
    def __init__(
        self,
        *,
        variates: int,
        embed_dim: int,
        seq_length: int,
        blocks: int,
        forecast_length: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.variates = variates
        # Step 2: Multi-layer Perceptron to embed each variate's series into a variate token
        self.embedding = nn.Sequential(
            # nn.Linear(seq_length, seq_length),
            # nn.ReLU(),
            nn.Linear(seq_length, embed_dim),
        )
        # TrmBlocks
        self.trm_blocks = nn.ModuleList(
            TrmBlock(embed_dim, dropout=dropout) for _ in range(blocks)
        )
        # Step 11: MLP for projecting tokens back to predicted series
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, forecast_length),
        )

    def forward(self, x):
        # x: (batch, seq_length, variates)
        assert x.shape[-1] == self.variates, "expected input of shape (batch, seq_length, variates)"
        # Step 1: Transpose the input so each variate's full series becomes one token
        x = x.transpose(-1, -2)  # (batch, variates, seq_length)
        # Step 3: Pass through MLP
        h = self.embedding(x)  # (batch, variates, embed_dim)
        # Step 4: Run through iTransformer blocks
        correlation_maps = []
        for trm_block in self.trm_blocks:
            h, correlation_map = trm_block(h)
            correlation_maps.append(correlation_map)
        # Step 11: Project tokens back to predicted series using MLP
        y = self.projection(h)  # (batch, variates, forecast_length)
        # Step 12: Transpose the result
        y = y.transpose(-1, -2)  # (batch, forecast_length, variates)
        # Step 13: Return the prediction result
        return ForecastPrediction(y, tuple(correlation_maps))


class TrmBlock(nn.Module):
    def __init__(self, embed_dim: int, *, dropout: float = 0.1):
        super().__init__()
        # Step 5: Self-attention across variate tokens
        self.self_attn = MultivariateAttention(embed_dim)
        # Step 6: H'^l = LayerNorm(H^{l-1} + Self-Attn(H^{l-1}))
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        # Step 7: Feed-forward network, applied to each variate token independently
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim),
        )
        # Step 9: Layer normalization for each block
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Multivariate self-attention layer
        attn_output, correlation_map = self.self_attn(x)
        # Step 6: LayerNorm is adopted on series representations to reduce discrepancies among variates.
        h1 = self.layer_norm1(x + attn_output)
        # Step 7: Feed-forward network is applied to the series representation of each variate token.
        ff_output = self.feed_forward(h1)
        # Step 9: LayerNorm is adopted on series representations to reduce discrepancies among variates.
        h2 = self.layer_norm2(h1 + ff_output)
        return h2, correlation_map


class MultivariateAttention(nn.Module):
    """Self-attention across variate tokens, returning the attention weights as a correlation map."""

    def __init__(self, embed_dim: int, num_heads: int = 1, *, dropout: float = 0.0) -> None:
        super().__init__()
        # batch_first=True so the input is read as (batch, tokens, features);
        # with the default batch_first=False, attention would be computed across the batch dimension.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )

    def forward(self, x):
        # x: (batch, variates, embed_dim) -- apply attention across variates (not temporal tokens),
        # so the attention weights form a (variates x variates) correlation map per sample.
        attn_output, correlation_map = self.self_attn(x, x, x)
        return attn_output, correlation_map
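

# Example usage: a minimal smoke-test sketch, assuming an input of shape
# (batch, seq_length, variates); the hyperparameters and tensor sizes below are
# arbitrary placeholders, not values taken from the iTransformer paper.
if __name__ == "__main__":
    model = iTransformer(
        variates=7,
        embed_dim=64,
        seq_length=96,
        blocks=2,
        forecast_length=24,
    )
    x = torch.randn(32, 96, 7)  # (batch, seq_length, variates)
    prediction = model(x)
    print(prediction.forecast.shape)             # torch.Size([32, 24, 7])
    print(prediction.correlation_maps[0].shape)  # torch.Size([32, 7, 7])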