Created
July 23, 2025 05:35
-
-
Save zhensongren/40fa8c8e93558a379dd395e863bb1689 to your computer and use it in GitHub Desktop.
nbeats
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import pandas as pd | |
import numpy as np | |
from darts import TimeSeries | |
from darts.models import NBEATSModel | |
from darts.metrics import mape | |
from darts.dataprocessing.transformers import Scaler | |
# %% | |
import numpy as np | |
import pandas as pd | |
from darts import TimeSeries | |
# Parameters | |
np.random.seed(42) | |
n_stores = 10 | |
n_products = 7 | |
n_time = 28 | |
dates = pd.date_range(start="2018-01-01", periods=n_time, freq="QE") | |
# Create long-format DataFrame | |
data = [] | |
for store_id in range(n_stores): | |
for product_id in range(n_products): | |
idx = store_id * n_products + product_id | |
seasonal = 10 + 5 * np.sin(np.arange(n_time) * 2 * np.pi / 4) | |
trend = np.linspace(50, 100, n_time) | |
noise = np.random.normal(0, 3, n_time) | |
sales = seasonal + trend + noise + idx * 10 | |
for i in range(n_time): | |
data.append({ | |
"time": dates[i], | |
"store_id": store_id, | |
"product_id": product_id, | |
"sales": sales[i] | |
}) | |
# Combine all into one DataFrame | |
sales_df = pd.DataFrame(data) | |
# Optional: check the structure | |
print(sales_df.head()) | |
# %% | |
# Build one TimeSeries per (store_id, product_id) group from the long-format
# sales frame.
series_list = TimeSeries.from_group_dataframe(
    sales_df,
    group_cols=["store_id", "product_id"],
    time_col="time",
    value_cols="sales",
    fill_missing_dates=True,  # only matters if some quarters are missing
)

# Quick look at the container type and at the first series.
first_series = series_list[0]
print(type(series_list))
print(first_series.static_covariates)
first_series.plot()
# %% | |
from darts import TimeSeries | |
import pandas as pd | |
import numpy as np | |
# The covariate covers the sales history (n_time) plus the future steps
# (forecast_horizon).
n_time = 28
forecast_horizon = 8
total_periods = n_time + forecast_horizon

# Synthetic market index on the same quarterly grid and start date as the
# sales data: a linear ramp from 1.0 to 2.0 with Gaussian noise added.
market_dates = pd.date_range(start="2018-01-01", periods=total_periods, freq="QE")
market_noise = np.random.normal(0, 0.1, total_periods)
market_trend = np.linspace(1.0, 2.0, total_periods) + market_noise

# Wrap in a DataFrame, then convert to a TimeSeries.
market_df = pd.DataFrame({"time": market_dates, "market_index": market_trend})
market_ts = TimeSeries.from_dataframe(
    market_df,
    time_col="time",
    value_cols="market_index",
)

# Check the covered time span and plot the index.
print(market_ts.start_time(), market_ts.end_time())
market_ts.plot()
# %% | |
# 3. Scale target and covariates, each with its own Scaler so the two can be
# inverted independently.
target_scaler, covariate_scaler = Scaler(), Scaler()

# Cast to float32 right after scaling (needed for MPS compatibility).
series_list_scaled = [
    scaled.astype(np.float32)
    for scaled in target_scaler.fit_transform(series_list)
]
market_ts_scaled = covariate_scaler.fit_transform(market_ts).astype(np.float32)
# %% | |
# 4. Train one NBEATS model on all scaled series at once.
# No covariates are passed to fit(), so the scaled market index is not used.
# NOTE(review): an earlier draft passed future_covariates using an undefined
# n_series; Darts documents N-BEATS as a past-covariates model, so confirm the
# covariate type before wiring market_ts_scaled in.
model = NBEATSModel(
    random_state=0,
    n_epochs=100,
    input_chunk_length=12,  # look-back window: 12 quarters
    output_chunk_length=8,  # forecast window: 8 quarters
)
model.fit(series=series_list_scaled, verbose=False)
# %% | |
from darts.metrics import mape, wmape | |
# 5. Backtest across all series.
# Errors are scored in the original sales units. Scoring the scaled values
# instead would change every denominator of the percentage metrics and make
# MAPE / wMAPE hard to interpret.
# NOTE(review): the model was fitted on these same series, so with
# retrain=False these are in-sample scores rather than a hold-out evaluation.
scaled_forecasts = [
    model.historical_forecasts(
        series=ts,
        forecast_horizon=8,  # 8 quarters ahead
        stride=1,
        retrain=False,  # reuse the already-fitted model
        verbose=False,
    )
    for ts in series_list_scaled
]

# Invert the scaling on the whole list at once, in the same order used for
# fitting, so each forecast is matched with its own series' scaler parameters.
unscaled_forecasts = target_scaler.inverse_transform(scaled_forecasts)

mape_scores = []
wmape_scores = []
for idx, (series, forecasts) in enumerate(zip(series_list, unscaled_forecasts)):
    # Restrict the actuals to the time span covered by the forecasts.
    actual = series.slice_intersect(forecasts)
    mape_val = mape(actual, forecasts)
    # calculate the weighted mape
    weighted_mape = wmape(actual, forecasts)
    print(f"Series {idx} MAPE: {mape_val:.2f}")
    mape_scores.append(mape_val)
    wmape_scores.append(weighted_mape)
# %% | |
# Report the mean of each metric over all backtested series.
mean_mape = np.mean(mape_scores)
mean_wmape = np.mean(wmape_scores)
print(f"\nAverage MAPE across all series: {mean_mape:.2f}")
print(f"\nAverage Weighted MAPE across all series: {mean_wmape:.2f}")
# %% | |
# Rolling backtest of the already-trained model on a single series.
# last_points_only=False keeps every 8-step forecast so each can be plotted.
series_to_backtest = series_list_scaled[0]  # first series; loop over all later
backtest_forecasts = model.historical_forecasts(
    series_to_backtest,
    forecast_horizon=8,  # 8 quarters ahead
    stride=1,  # 1 = a forecast at every step; higher = spaced out
    retrain=False,  # model is already trained
    last_points_only=False,
    verbose=True,
)

# Overlay each forecast on the (scaled) actual series.
series_to_backtest.plot(label="actual")
for window_no, window_forecast in enumerate(backtest_forecasts):
    window_forecast.plot(label=f"forecast_{window_no}")
# %% [markdown] | |
# Backtesting the same series with AutoARIMA from Darts
# %% | |
from darts.models import AutoARIMA | |
import matplotlib.pyplot as plt | |
# Use a single series for AutoARIMA (no global learning)
series_to_backtest = series_list_scaled[0]

# Use a dedicated name for the AutoARIMA model. Rebinding `model` would
# overwrite the trained NBEATS model, and re-running an earlier backtest cell
# (which calls model.historical_forecasts with retrain=False) would then break.
# NOTE(review): seasonal settings for AutoARIMA depend on the installed Darts
# version; check its documentation before enabling quarterly seasonality.
arima_model = AutoARIMA()

# Backtest with historical_forecasts; retrain=True refits the model on the
# training window before every forecast.
backtest_forecasts = arima_model.historical_forecasts(
    series=series_to_backtest,
    train_length=12,  # 12 quarters of data to train on
    forecast_horizon=8,  # 8 quarters ahead
    stride=1,
    retrain=True,  # AutoARIMA must be retrained at each step
    last_points_only=False,  # keep every full 8-step forecast
    verbose=True,
)

# Plot actual series and forecasts
series_to_backtest.plot(label="actual")
for i, fcast in enumerate(backtest_forecasts):
    fcast.plot(label=f"forecast_{i}")
plt.title("AutoARIMA Backtest Forecasts")
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment