Skip to content

Instantly share code, notes, and snippets.

@shortthirdman
Created April 24, 2025 12:11
Show Gist options
  • Save shortthirdman/171b66635f85050c76909c9d6abd7645 to your computer and use it in GitHub Desktop.
Save shortthirdman/171b66635f85050c76909c9d6abd7645 to your computer and use it in GitHub Desktop.
Stock Market Forecasting with Differential Graph Transformer
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Install Dependencies"
],
"metadata": {
"id": "IroaP_EM55Oc"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"%pip install torch_geometric torch pandas wandb"
],
"metadata": {
"id": "1l1soPp457nm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"import os\n",
"\n",
"# Create a directory in your Google Drive\n",
"workdir = '/content/drive/MyDrive/Colab Notebooks/stock_dgt/'\n",
"\n",
"# Remove and recreate directory\n",
"if os.path.exists(workdir):\n",
" shutil.rmtree(workdir)\n",
"os.makedirs(workdir)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YzAaU-adD4kF",
"outputId": "5ea9b8be-81c7-4f51-ea74-c5f9b499523b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NjKkCviZ0dFF"
},
"source": [
"# Dataset Construction"
]
},
{
"cell_type": "markdown",
"source": [
"## Download Dataset"
],
"metadata": {
"id": "UEE6lvSA0z74"
}
},
{
"cell_type": "code",
"source": [
"# Clone the repository to download the S&P500 stock prices, precomputed correlation matrcies,\n",
"# along with trained model weights for ease of evaluation\n",
"!git clone https://github.com/AlienKevin/sp500.git\n",
"\n",
"import shutil\n",
"import os\n",
"\n",
"repo_name = \"sp500\"\n",
"for file_name in os.listdir(repo_name):\n",
" shutil.move(os.path.join(repo_name, file_name), workdir)\n",
"\n",
"# Remove the cloned repository folder\n",
"shutil.rmtree(repo_name)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IFxwrsQ71FbF",
"outputId": "f1e67512-c405-4dcf-842d-af7e7a489b8f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fatal: destination path 'sp500' already exists and is not an empty directory.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Exploratory Data Analysis Shows Superiority of Mutual Information in Capturing Interstock Relationships"
],
"metadata": {
"id": "mVwZhqFx-KRY"
}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Plot the 3 most correlated stocks to the target_stock based on corr_name with scope corr_scope\n",
"def plot_most_correlated_stocks(target_stock, corr_name, corr_scope):\n",
" df = pd.read_csv(f'{workdir}/sp500.csv')\n",
" df['Date'] = pd.to_datetime(df['Date'])\n",
" df = df.set_index('Date')\n",
"\n",
" target_index = df.columns.get_loc(target_stock)\n",
"\n",
" corr = np.loadtxt(f'{workdir}/{corr_name}/{corr_scope}.csv', delimiter=',')\n",
"\n",
" top_3_correlated_indices = corr[target_index].argsort()[-4:][::-1]\n",
" top_3_correlated_stocks = df.columns[top_3_correlated_indices]\n",
"\n",
" plt.clf()\n",
" plt.figure(figsize=(12, 6))\n",
" plt.style.use('default')\n",
"\n",
" for stock in top_3_correlated_stocks:\n",
" if corr_scope.startswith('global'):\n",
" # Plot the entire duration of the dataset for global correlations\n",
" plt.plot(df.index, df[stock], label=stock)\n",
" else:\n",
" # Only plot the time window corresponding to the local correlations\n",
" num_days_in_quarter = 64\n",
" quarter_index = int(corr_scope.split('_')[-1])\n",
" quarter_start_index = quarter_index * num_days_in_quarter\n",
" quarter_end_index = (quarter_index + 1) * num_days_in_quarter\n",
" print('Quarter Start date', df.index[quarter_start_index])\n",
" print('Quarter End date', df.index[quarter_end_index])\n",
" quarter_df = df.iloc[quarter_start_index:quarter_end_index]\n",
" plt.plot(quarter_df.index, quarter_df[stock], label=stock)\n",
"\n",
" plt.title(f\"Top 3 Correlated Stocks with {top_3_correlated_stocks[0]} using {'Global' if corr_scope.startswith('global') else 'Local'} {'Pearson' if corr_name == 'pcc' else 'Mutual Information'}: {', '.join(top_3_correlated_stocks[1:])}\")\n",
" plt.xlabel(\"Date\")\n",
" plt.ylabel(\"Price\")\n",
" plt.legend()\n",
" plt.grid(True)\n",
" plt.show()\n"
],
"metadata": {
"id": "wOQvTHmp-Pdd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Global Mutual Information Captures Shared Trends Well"
],
"metadata": {
"id": "wq6NugHBLJ8P"
}
},
{
"cell_type": "code",
"source": [
"plot_most_correlated_stocks('AAPL', 'mi', 'global_corr')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 581
},
"id": "G0t3kRxQEFo1",
"outputId": "93a46323-39dc-4a88-969e-04120741a6d1"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Global Pearson Struggles with Nonlinearity in the Market"
],
"metadata": {
"id": "Gir8jsqCLTEm"
}
},
{
"cell_type": "code",
"source": [
"plot_most_correlated_stocks('AAPL', 'pcc', 'global_corr')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 581
},
"id": "vqRlDukPEMMy",
"outputId": "3b7200e1-d502-4a20-e663-a480a0c4da0f"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Mutual Information and Pearson Perform Similarly Well on a Local Scope (the Length of 1 Fiscal Quarter)"
],
"metadata": {
"id": "hQi4BHLALhc0"
}
},
{
"cell_type": "code",
"source": [
"plot_most_correlated_stocks('AAPL', 'mi', 'local_corr_38')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 720
},
"id": "4c0m35WzGEdh",
"outputId": "a286a360-5fc3-4abd-83dc-6abed024de7d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"plot_most_correlated_stocks('AAPL', 'pcc', 'local_corr_38')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 720
},
"id": "6xvYLO-_GRW5",
"outputId": "d523a91f-68f1-4152-90c2-56235522fb6b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n",
"Quarter Start date 2024-07-03 00:00:00+00:00\n",
"Quarter End date 2024-10-03 00:00:00+00:00\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Construct Temporal PyG Dataset"
],
"metadata": {
"id": "yHdXqx-H2yhI"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xjihcgYj0dFH"
},
"outputs": [],
"source": [
"# Copied from PyG temporal rather than imported because the library has dependency issues with PyG\n",
"# https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/signal/dynamic_graph_temporal_signal.html\n",
"\n",
"from typing import Sequence, Union\n",
"import numpy as np\n",
"\n",
"Edge_Indices = Sequence[Union[np.ndarray, None]]\n",
"Edge_Weights = Sequence[Union[np.ndarray, None]]\n",
"Node_Features = Sequence[Union[np.ndarray, None]]\n",
"Targets = Sequence[Union[np.ndarray, None]]\n",
"Additional_Features = Sequence[np.ndarray]\n",
"\n",
"class DynamicGraphTemporalSignal(object):\n",
" r\"\"\"A data iterator object to contain a dynamic graph with a\n",
" changing edge set and weights . The feature set and node labels\n",
" (target) are also dynamic. The iterator returns a single discrete temporal\n",
" snapshot for a time period (e.g. day or week). This single snapshot is a\n",
" Pytorch Geometric Data object. Between two temporal snapshots the edges,\n",
" edge weights, target matrices and optionally passed attributes might change.\n",
"\n",
" Args:\n",
" edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.\n",
" edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.\n",
" features (Sequence of Numpy arrays): Sequence of node feature tensors.\n",
" targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.\n",
" **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" edge_indices: Edge_Indices,\n",
" edge_weights: Edge_Weights,\n",
" features: Node_Features,\n",
" targets: Targets,\n",
" **kwargs: Additional_Features\n",
" ):\n",
" self.edge_indices = edge_indices\n",
" self.edge_weights = edge_weights\n",
" self.features = features\n",
" self.targets = targets\n",
" self.additional_feature_keys = []\n",
" for key, value in kwargs.items():\n",
" setattr(self, key, value)\n",
" self.additional_feature_keys.append(key)\n",
" self._check_temporal_consistency()\n",
" self._set_snapshot_count()\n",
"\n",
" def _check_temporal_consistency(self):\n",
" assert len(self.features) == len(\n",
" self.targets\n",
" ), \"Temporal dimension inconsistency.\"\n",
" assert len(self.edge_indices) == len(\n",
" self.edge_weights\n",
" ), \"Temporal dimension inconsistency.\"\n",
" assert len(self.features) == len(\n",
" self.edge_weights\n",
" ), \"Temporal dimension inconsistency.\"\n",
" for key in self.additional_feature_keys:\n",
" assert len(self.targets) == len(\n",
" getattr(self, key)\n",
" ), \"Temporal dimension inconsistency.\"\n",
"\n",
" def _set_snapshot_count(self):\n",
" self.snapshot_count = len(self.features)\n",
"\n",
" def _get_edge_index(self, time_index: int):\n",
" if self.edge_indices[time_index] is None:\n",
" return self.edge_indices[time_index]\n",
" else:\n",
" return torch.LongTensor(self.edge_indices[time_index])\n",
"\n",
" def _get_edge_weight(self, time_index: int):\n",
" if self.edge_weights[time_index] is None:\n",
" return self.edge_weights[time_index]\n",
" else:\n",
" return torch.FloatTensor(self.edge_weights[time_index])\n",
"\n",
" def _get_features(self, time_index: int):\n",
" if self.features[time_index] is None:\n",
" return self.features[time_index]\n",
" else:\n",
" return torch.FloatTensor(self.features[time_index])\n",
"\n",
" def _get_target(self, time_index: int):\n",
" if self.targets[time_index] is None:\n",
" return self.targets[time_index]\n",
" else:\n",
" if self.targets[time_index].dtype.kind == \"i\":\n",
" return torch.LongTensor(self.targets[time_index])\n",
" elif self.targets[time_index].dtype.kind == \"f\":\n",
" return torch.FloatTensor(self.targets[time_index])\n",
"\n",
" def _get_additional_feature(self, time_index: int, feature_key: str):\n",
" feature = getattr(self, feature_key)[time_index]\n",
" if feature.dtype.kind == \"i\":\n",
" return torch.LongTensor(feature)\n",
" elif feature.dtype.kind == \"f\":\n",
" return torch.FloatTensor(feature)\n",
"\n",
" def _get_additional_features(self, time_index: int):\n",
" additional_features = {\n",
" key: self._get_additional_feature(time_index, key)\n",
" for key in self.additional_feature_keys\n",
" }\n",
" return additional_features\n",
"\n",
" def __getitem__(self, time_index: Union[int, slice]):\n",
" if isinstance(time_index, slice):\n",
" snapshot = DynamicGraphTemporalSignal(\n",
" self.edge_indices[time_index],\n",
" self.edge_weights[time_index],\n",
" self.features[time_index],\n",
" self.targets[time_index],\n",
" **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}\n",
" )\n",
" else:\n",
" x = self._get_features(time_index)\n",
" edge_index = self._get_edge_index(time_index)\n",
" edge_weight = self._get_edge_weight(time_index)\n",
" y = self._get_target(time_index)\n",
" additional_features = self._get_additional_features(time_index)\n",
"\n",
" snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,\n",
" y=y, **additional_features)\n",
" return snapshot\n",
"\n",
" def __next__(self):\n",
" if self.t < len(self.features):\n",
" snapshot = self[self.t]\n",
" self.t = self.t + 1\n",
" return snapshot\n",
" else:\n",
" self.t = 0\n",
" raise StopIteration\n",
"\n",
" def __iter__(self):\n",
" self.t = 0\n",
" return self"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9D-3UwVo0dFG"
},
"outputs": [],
"source": [
"# Copied from PyG temporal rather than imported because the library has dependency issues with PyG\n",
"# https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/signal/static_graph_temporal_signal.html#StaticGraphTemporalSignal\n",
"import torch\n",
"import numpy as np\n",
"from typing import Sequence, Union\n",
"from torch_geometric.data import Data\n",
"\n",
"\n",
"Edge_Index = Union[np.ndarray, None]\n",
"Edge_Weight = Union[np.ndarray, None]\n",
"Node_Features = Sequence[Union[np.ndarray, None]]\n",
"Targets = Sequence[Union[np.ndarray, None]]\n",
"Additional_Features = Sequence[np.ndarray]\n",
"\n",
"class StaticGraphTemporalSignal(object):\n",
" r\"\"\"A data iterator object to contain a static graph with a dynamically\n",
" changing constant time difference temporal feature set (multiple signals).\n",
" The node labels (target) are also temporal. The iterator returns a single\n",
" constant time difference temporal snapshot for a time period (e.g. day or week).\n",
" This single temporal snapshot is a Pytorch Geometric Data object. Between two\n",
" temporal snapshots the features and optionally passed attributes might change.\n",
" However, the underlying graph is the same.\n",
"\n",
" Args:\n",
" edge_index (Numpy array): Index tensor of edges.\n",
" edge_weight (Numpy array): Edge weight tensor.\n",
" features (Sequence of Numpy arrays): Sequence of node feature tensors.\n",
" targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.\n",
" **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" edge_index: Edge_Index,\n",
" edge_weight: Edge_Weight,\n",
" features: Node_Features,\n",
" targets: Targets,\n",
" **kwargs: Additional_Features\n",
" ):\n",
" self.edge_index = edge_index\n",
" self.edge_weight = edge_weight\n",
" self.features = features\n",
" self.targets = targets\n",
" self.additional_feature_keys = []\n",
" for key, value in kwargs.items():\n",
" setattr(self, key, value)\n",
" self.additional_feature_keys.append(key)\n",
" self._check_temporal_consistency()\n",
" self._set_snapshot_count()\n",
"\n",
" def _check_temporal_consistency(self):\n",
" assert len(self.features) == len(\n",
" self.targets\n",
" ), \"Temporal dimension inconsistency.\"\n",
" for key in self.additional_feature_keys:\n",
" assert len(self.targets) == len(\n",
" getattr(self, key)\n",
" ), \"Temporal dimension inconsistency.\"\n",
"\n",
" def _set_snapshot_count(self):\n",
" self.snapshot_count = len(self.features)\n",
"\n",
" def _get_edge_index(self):\n",
" if self.edge_index is None:\n",
" return self.edge_index\n",
" else:\n",
" return torch.LongTensor(self.edge_index)\n",
"\n",
" def _get_edge_weight(self):\n",
" if self.edge_weight is None:\n",
" return self.edge_weight\n",
" else:\n",
" return torch.FloatTensor(self.edge_weight)\n",
"\n",
" def _get_features(self, time_index: int):\n",
" if self.features[time_index] is None:\n",
" return self.features[time_index]\n",
" else:\n",
" return torch.FloatTensor(self.features[time_index])\n",
"\n",
" def _get_target(self, time_index: int):\n",
" if self.targets[time_index] is None:\n",
" return self.targets[time_index]\n",
" else:\n",
" if self.targets[time_index].dtype.kind == \"i\":\n",
" return torch.LongTensor(self.targets[time_index])\n",
" elif self.targets[time_index].dtype.kind == \"f\":\n",
" return torch.FloatTensor(self.targets[time_index])\n",
"\n",
" def _get_additional_feature(self, time_index: int, feature_key: str):\n",
" feature = getattr(self, feature_key)[time_index]\n",
" if feature.dtype.kind == \"i\":\n",
" return torch.LongTensor(feature)\n",
" elif feature.dtype.kind == \"f\":\n",
" return torch.FloatTensor(feature)\n",
"\n",
" def _get_additional_features(self, time_index: int):\n",
" additional_features = {\n",
" key: self._get_additional_feature(time_index, key)\n",
" for key in self.additional_feature_keys\n",
" }\n",
" return additional_features\n",
"\n",
" def __getitem__(self, time_index: Union[int, slice]):\n",
" if isinstance(time_index, slice):\n",
" snapshot = StaticGraphTemporalSignal(\n",
" self.edge_index,\n",
" self.edge_weight,\n",
" self.features[time_index],\n",
" self.targets[time_index],\n",
" **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}\n",
" )\n",
" else:\n",
" x = self._get_features(time_index)\n",
" edge_index = self._get_edge_index()\n",
" edge_weight = self._get_edge_weight()\n",
" y = self._get_target(time_index)\n",
" additional_features = self._get_additional_features(time_index)\n",
"\n",
" snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,\n",
" y=y, **additional_features)\n",
" return snapshot\n",
"\n",
" def __next__(self):\n",
" if self.t < len(self.features):\n",
" snapshot = self[self.t]\n",
" self.t = self.t + 1\n",
" return snapshot\n",
" else:\n",
" self.t = 0\n",
" raise StopIteration\n",
"\n",
" def __iter__(self):\n",
" self.t = 0\n",
" return self"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PGE29GpL0dFH"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from typing import Union\n",
"import glob\n",
"from natsort import natsorted\n",
"import random\n",
"\n",
"# Fix random seed for ease of reproduction\n",
"seed = 42\n",
"random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"\n",
"# Dataset loader for SP500 stock prices\n",
"class SP500CorrelationsDatasetLoader(object):\n",
" def __init__(self, corr_name, corr_scope):\n",
" self._read_csv(corr_name, corr_scope)\n",
"\n",
" # Load a global correlation under the name corr_name\n",
" def _load_global_corr(self, corr_name):\n",
" return np.loadtxt(f'{workdir}/{corr_name}/global_corr.csv', delimiter=',')\n",
"\n",
" # Load a local correlation under the name corr_name\n",
" def _load_local_corrs(self, corr_name):\n",
" _correlation_matrices = []\n",
" corr_files = natsorted(glob.glob(f'{workdir}/{corr_name}/local_corr_*.csv'))\n",
" for corr_file in corr_files:\n",
" matrix = np.loadtxt(corr_file, delimiter=',')\n",
" _correlation_matrices.append(matrix)\n",
" return _correlation_matrices\n",
"\n",
" # Helper function for reading a correlation with type corr_name and scope corr_scope from CSV file\n",
" def _read_csv(self, corr_name, corr_scope):\n",
" match corr_scope:\n",
" case 'global':\n",
" self._correlation_matrices = [self._load_global_corr(corr_name)]\n",
" case 'local':\n",
" self._correlation_matrices = self._load_local_corrs(corr_name)\n",
" case 'dual':\n",
" # Stack global and local correlation matrices for dual correlation\n",
" global_corr = self._load_global_corr(corr_name)\n",
" self._correlation_matrices = [np.stack((global_corr, local_corr), axis=-1) for local_corr in self._load_local_corrs(corr_name)]\n",
" case None:\n",
" # None uses identity matrix as correlation\n",
" # Infer dimension from a global correlation matrix\n",
" global_corr = self._load_global_corr('pcc')\n",
" self._correlation_matrices = [np.eye(global_corr.shape[0], global_corr.shape[1])]\n",
"\n",
" if corr_name == 'mi':\n",
" # Normalize MI to [0, 1]\n",
" max_mi = 0\n",
" for matrix in self._correlation_matrices:\n",
" max_mi = max(np.max(matrix), max_mi)\n",
" # MI shouldn't be negative\n",
" matrix[matrix < 0] = 0\n",
" for matrix in self._correlation_matrices:\n",
" matrix = matrix / max_mi\n",
"\n",
" df = pd.read_csv(f'{workdir}/sp500.csv')\n",
" df = df.set_index('Date')\n",
" data = torch.from_numpy(df.to_numpy()).to(torch.float32)\n",
"\n",
" # Round data size to nearest multiple of batch_size\n",
" self.days_in_quarter = 64\n",
" num_quarters = data.size(0) // self.days_in_quarter\n",
" num_days = num_quarters * self.days_in_quarter\n",
" data = data[:num_days]\n",
"\n",
" # z-score normalization with training data following GERU\n",
" train_days = int(0.8 * num_quarters) * self.days_in_quarter\n",
" data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)\n",
" data = data.numpy()\n",
"\n",
" data = data[..., np.newaxis]\n",
"\n",
" assert(not np.any(np.isnan(data)))\n",
" self._dataset = data\n",
"\n",
" def _get_edges(self, times, overlap):\n",
" # Construct a fully-connected graph\n",
" def helper(corr_index):\n",
" return np.array(np.ones(self._correlation_matrices[corr_index].shape[:2]).nonzero())\n",
"\n",
" if len(self._correlation_matrices) == 1:\n",
" _edges = helper(0)\n",
" else:\n",
" _edges = []\n",
" for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):\n",
" if not time in times:\n",
" continue\n",
" corr_index = max(0, time // self.days_in_quarter - 1)\n",
" _edges.append(\n",
" helper(corr_index)\n",
" )\n",
" return _edges\n",
"\n",
" def _get_edge_weights(self, times, overlap):\n",
" # Edge weights are the correlations between stocks\n",
" def helper(corr_index):\n",
" w = self._correlation_matrices[corr_index]\n",
" # Flatten the first two dimensions\n",
" return w.reshape((w.shape[0] * w.shape[1],) + w.shape[2:])\n",
"\n",
" if len(self._correlation_matrices) == 1:\n",
" _edge_weights = helper(0)\n",
" else:\n",
" _edge_weights = []\n",
" for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):\n",
" if not time in times:\n",
" continue\n",
" corr_index = max(0, time // self.days_in_quarter - 1)\n",
" _edge_weights.append(\n",
" helper(corr_index)\n",
" )\n",
" return _edge_weights\n",
"\n",
" def _get_targets_and_features(self, times, overlap, predict_all):\n",
" # Given previous batch_size stock prices...\n",
" features = [\n",
" self._dataset[i : i + self.batch_size, :]\n",
" for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)\n",
" if i in times\n",
" ]\n",
" # predict next-day stock prices\n",
" targets = [\n",
" (self._dataset[i+1 : i + self.batch_size+1, :, 0]).T if predict_all else (self._dataset[i + self.batch_size, :, 0]).T\n",
" for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)\n",
" if i in times\n",
" ]\n",
" return features, targets\n",
"\n",
" def get_dataset(self, batch_size, split) -> Union[StaticGraphTemporalSignal, DynamicGraphTemporalSignal]:\n",
" # Returning the data iterator where the train is designed for many-to-many predictions (each day predict next day's price)\n",
" # while the validation and test are many-to-one predictions (many past days predict tomorrow's price)\n",
"\n",
" self.batch_size = batch_size\n",
"\n",
" total_times = list(range(0, self._dataset.shape[0] - self.batch_size, self.batch_size))\n",
"\n",
" # We do a 8-1-1 split for train-validation-test. Since the test set is one year apart from training,\n",
" # It is much more challenging to predict.\n",
" if split == 'train':\n",
" times = list(range(total_times[int(len(total_times) * 0)], total_times[int(len(total_times) * 0.8)]))\n",
" overlap = self.batch_size\n",
" predict_all = True\n",
" elif split == 'val':\n",
" times = list(range(total_times[int(len(total_times) * 0.8)], total_times[int(len(total_times) * 0.9)]))\n",
" overlap = 1\n",
" predict_all = False\n",
" elif split == 'test':\n",
" times = list(range(total_times[int(len(total_times) * 0.9)], total_times[-1] + self.batch_size))\n",
" overlap = 1\n",
" predict_all = False\n",
" else:\n",
" raise ValueError(f'Invalid split name: {split}')\n",
"\n",
" _edges = self._get_edges(times, overlap)\n",
" _edge_weights = self._get_edge_weights(times, overlap)\n",
" features, targets = self._get_targets_and_features(times, overlap, predict_all)\n",
" dataset = (DynamicGraphTemporalSignal if type(_edges) == list else StaticGraphTemporalSignal)(\n",
" _edges, _edge_weights, features, targets\n",
" )\n",
" return dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-9rcAGsl0dFH"
},
"outputs": [],
"source": [
"# Helper function to get the dataset for a correlation\n",
"def get_dataset(corr_name, corr_scope):\n",
" loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)\n",
"\n",
" lag_size = 64\n",
" # Train dataset has double the batch_size because it's trained under many-to-many prediction.\n",
" # During test time, the model is used for many-to-one prediction given batch_size previous days.\n",
" # Hence, we need to have a larger training batch_size than the lag_size during test.\n",
" train_dataset = loader.get_dataset(batch_size=lag_size * 2, split='train')\n",
" val_dataset = loader.get_dataset(batch_size=lag_size, split='val')\n",
" test_dataset = loader.get_dataset(batch_size=lag_size, split='test')\n",
"\n",
" train_samples = list(train_dataset)\n",
" val_samples = list(val_dataset)\n",
" test_samples = list(test_dataset)\n",
"\n",
" return {\n",
" 'train_samples': train_samples,\n",
" 'val_samples': val_samples,\n",
" 'test_samples': test_samples,\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "93mt0lVh0dFH"
},
"source": [
"# Differential Graph Transformer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9-REadVV0dFH"
},
"outputs": [],
"source": [
"# Adapted from reference implementation of Differential Transformer, included an optional A input to MultiheadDiffAttn.forward()\n",
"# https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_diffattn.py\n",
"\n",
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):\n",
" super().__init__()\n",
" self.dim = dim\n",
" self.eps = eps\n",
" self.elementwise_affine = elementwise_affine\n",
" if self.elementwise_affine:\n",
" self.weight = nn.Parameter(torch.ones(dim))\n",
" else:\n",
" self.register_parameter('weight', None)\n",
"\n",
" def _norm(self, x):\n",
" return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
"\n",
" def forward(self, x):\n",
" output = self._norm(x.float()).type_as(x)\n",
" if self.weight is not None:\n",
" output = output * self.weight\n",
" return output\n",
"\n",
" def extra_repr(self) -> str:\n",
" return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'\n",
"\n",
"\n",
"def init_method(tensor, **kwargs):\n",
" nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))\n",
"\n",
"def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
" \"\"\"torch.repeat_interleave(x, dim=1, repeats=n_rep)\"\"\"\n",
" bs, n_kv_heads, slen, head_dim = x.shape\n",
" if n_rep == 1:\n",
" return x\n",
" return (\n",
" x[:, :, None, :, :]\n",
" .expand(bs, n_kv_heads, n_rep, slen, head_dim)\n",
" .reshape(bs, n_kv_heads * n_rep, slen, head_dim)\n",
" )\n",
"\n",
"def lambda_init_fn(depth):\n",
" return 0.8 - 0.6 * math.exp(-0.3 * depth)\n",
"\n",
"\n",
"# Differential Graph Attention with multiple heads\n",
"class MultiheadDiffAttn(nn.Module):\n",
" def __init__(\n",
" self,\n",
" embed_dim,\n",
" depth,\n",
" num_heads,\n",
" ):\n",
" super().__init__()\n",
" self.embed_dim = embed_dim\n",
" # num_heads set to half of Transformer's #heads\n",
" self.num_heads = num_heads\n",
" self.num_kv_heads = num_heads\n",
" self.n_rep = self.num_heads // self.num_kv_heads\n",
"\n",
" self.head_dim = embed_dim // num_heads // 2\n",
" self.scaling = self.head_dim ** -0.5\n",
"\n",
" self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
" self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)\n",
" self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)\n",
"\n",
" self.lambda_init = lambda_init_fn(depth)\n",
" self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))\n",
" self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))\n",
" self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))\n",
" self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))\n",
"\n",
" self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)\n",
"\n",
" def forward(\n",
" self,\n",
" x,\n",
" A=None,\n",
" attn_mask=None,\n",
" ):\n",
" bsz, tgt_len, embed_dim = x.size()\n",
" src_len = tgt_len\n",
"\n",
" # Project input x into query, key, and value\n",
" q = self.q_proj(x)\n",
" k = self.k_proj(x)\n",
" v = self.v_proj(x)\n",
"\n",
" q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)\n",
" k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)\n",
" v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)\n",
"\n",
" q = q.transpose(1, 2)\n",
" k = repeat_kv(k.transpose(1, 2), self.n_rep)\n",
" v = repeat_kv(v.transpose(1, 2), self.n_rep)\n",
" q *= self.scaling\n",
"\n",
" # Compute attention weights by multiplying query and key\n",
" attn_weights = torch.matmul(q, k.transpose(-1, -2))\n",
" attn_weights = torch.nan_to_num(attn_weights)\n",
" # Apply attention mask\n",
" if attn_mask is not None:\n",
" attn_weights += attn_mask\n",
" # Calculate attention scores using softmax\n",
" attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(\n",
" attn_weights\n",
" )\n",
"\n",
" # Calculate the lambda used for differential attention\n",
" lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)\n",
" lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)\n",
" lambda_full = lambda_1 - lambda_2 + self.lambda_init\n",
"\n",
" # Optionally condition the differential attention on a graph prior A\n",
" attn_weights = attn_weights.view(bsz, self.num_heads, 2, tgt_len, src_len)\n",
" attn_weights = attn_weights[:, :, 0] * (1 if A is None else A) - lambda_full * attn_weights[:, :, 1]\n",
"\n",
" # Compute output embeddings by mixing values based on their attention scores\n",
" attn = torch.matmul(attn_weights, v)\n",
" attn = self.subln(attn)\n",
" attn = attn * (1 - self.lambda_init)\n",
" attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)\n",
" return (attn, attn_weights)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K0G6q4eQ0dFH"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch_geometric.utils import to_dense_adj\n",
"\n",
"# A normal feedforward layer\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, hidden_size, expand_ratio, dropout):\n",
" super(FeedForward, self).__init__()\n",
" self.linear = nn.Linear(hidden_size, hidden_size * expand_ratio)\n",
" self.linear2 = nn.Linear(hidden_size * expand_ratio, hidden_size)\n",
" self.relu = nn.ReLU()\n",
" self.dropout = nn.Dropout(p=dropout)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear(x)\n",
" x = self.relu(x)\n",
" x = self.linear2(x)\n",
" x = self.dropout(x)\n",
" return x\n",
"\n",
"# Differential attention followed by a residual feedforward block and layer norm\n",
"class Attention(nn.Module):\n",
"    \"\"\"Wrap MultiheadDiffAttn with a feedforward block and a LayerNorm.\n",
"\n",
"    Args:\n",
"        d_model: embedding width of the inputs and outputs.\n",
"        num_heads: number of differential attention heads.\n",
"        expand_ratio: hidden-width multiplier of the feedforward block.\n",
"        dropout: dropout probability of the feedforward block.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model, num_heads, expand_ratio, dropout):\n",
"        super().__init__()\n",
"        self.mha = MultiheadDiffAttn(embed_dim=d_model, num_heads=num_heads, depth=0)\n",
"        self.ln2 = nn.LayerNorm(d_model)\n",
"        self.ffn = FeedForward(hidden_size=d_model, expand_ratio=expand_ratio, dropout=dropout)\n",
"\n",
"    def forward(self, x, A=None, attn_mask=None, need_weights=False):\n",
"        \"\"\"Attend over x and post-process the result.\n",
"\n",
"        The residual connection inside this block is taken around the feedforward\n",
"        only; the input x itself is not added back here.\n",
"\n",
"        Returns:\n",
"            The normalized output, or (output, attention weights) when\n",
"            need_weights is True.\n",
"        \"\"\"\n",
"        attended, attn_weights = self.mha(x, A, attn_mask=attn_mask)\n",
"        normalized = self.ln2(self.ffn(attended) + attended)\n",
"        if not need_weights:\n",
"            return normalized\n",
"        return (normalized, attn_weights)\n",
"\n",
"# Differential Graph Transformer = temporal attention + spatial attention\n",
"# Spatial attention may optionally receive an adjacency matrix for conditioning.\n",
"class DGT(nn.Module):\n",
"    \"\"\"Stack of temporal (causal, per stock) and spatial (per time step) attention layers.\n",
"\n",
"    Args:\n",
"        in_channels: number of input features per stock and time step.\n",
"        out_channels: embedding width used throughout the model.\n",
"        num_heads: number of differential attention heads per layer.\n",
"        num_layers: number of temporal (+ spatial) layers.\n",
"        expand_ratio: hidden-width multiplier of each feedforward block.\n",
"        dropout: dropout probability of each feedforward block.\n",
"        T: maximum sequence length (size of the time embedding table).\n",
"        N: number of stocks (size of the stock embedding table).\n",
"        use_spatial: when False, only the temporal attention layers are built and run.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=1, out_channels=32, num_heads=2, num_layers=2, expand_ratio=1, dropout=0.1, T=128, N=472, use_spatial=True):\n",
"        super().__init__()\n",
"        self.T = T\n",
"        self.N = N\n",
"        self.d_model = out_channels\n",
"        self.num_heads = num_heads\n",
"        self.num_layers = num_layers\n",
"        self.input_proj = nn.Linear(in_channels, out_channels)\n",
"        self.time_embedding = nn.Embedding(T, out_channels)\n",
"        self.stock_embedding = nn.Embedding(N, out_channels)\n",
"        self.use_spatial = use_spatial\n",
"        if use_spatial:\n",
"            self.spatial_attns = nn.ModuleList([Attention(out_channels, num_heads, expand_ratio, dropout) for _ in range(num_layers)])\n",
"        self.temporal_attns = nn.ModuleList([Attention(out_channels, num_heads, expand_ratio, dropout) for _ in range(num_layers)])\n",
"\n",
"    def forward(self, x, edge_index, edge_weight, need_weights=False):\n",
"        \"\"\"Embed the price history and run it through every DGT layer.\n",
"\n",
"        Args:\n",
"            x: tensor of shape (N, T, in_channels) with T <= self.T and N == self.N.\n",
"            edge_index: graph connectivity handed to to_dense_adj.\n",
"            edge_weight: edge attributes handed to to_dense_adj. Attributes with a\n",
"                trailing channel dimension yield one adjacency matrix per channel,\n",
"                which must broadcast against the attention heads.\n",
"            need_weights: when True, also return the attention maps of every layer.\n",
"\n",
"        Returns:\n",
"            Node embeddings of shape (N, T, out_channels). When need_weights is True,\n",
"            a tuple (embeddings, weights) where weights holds one dict per layer with\n",
"            the keys 'temporal' and 'spatial' ('spatial' is None without spatial attention).\n",
"        \"\"\"\n",
"        N, T, D = x.size()\n",
"        assert D == self.input_proj.in_features, 'unexpected number of input features'\n",
"        assert T <= self.T and N == self.N, 'input does not fit the embedding tables'\n",
"        device = x.device\n",
"\n",
"        # Compute initial node embedding for the graph transformer\n",
"        # Node embedding incorporates current stock prices, stock embeddings, and time embeddings.\n",
"        x = self.input_proj(x.permute(1, 0, 2))  # T, N, C\n",
"        x = x + self.stock_embedding(torch.arange(N, device=device)).unsqueeze(0)  # broadcast over T\n",
"        x = x + self.time_embedding(torch.arange(T, device=device)).unsqueeze(1)  # broadcast over N\n",
"        x = x.permute(1, 0, 2)  # N, T, C\n",
"\n",
"        # The causal mask does not depend on the layer: build it once, directly on the device.\n",
"        # -inf strictly above the diagonal hides future time steps.\n",
"        temporal_causal_mask = torch.triu(\n",
"            torch.full((T, T), float('-inf'), dtype=torch.float32, device=device),\n",
"            diagonal=1,\n",
"        ).expand(N, self.num_heads * 2, T, T)\n",
"\n",
"        # The dense adjacency is layer-independent as well\n",
"        A = None\n",
"        if self.use_spatial:\n",
"            # max_num_nodes keeps the matrix N x N even if the last stocks have no edges\n",
"            A = to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=N)\n",
"            # Encountered more than one adjacency matrices, e.g. dual correlations\n",
"            if A.dim() == 4:\n",
"                # to_dense_adj returns (1, N, N, C); the channel axis has to be *moved* to the\n",
"                # front. A plain reshape would reinterpret memory and mix entries of different\n",
"                # channels. NOTE(review): checkpoints trained with the old reshape saw the\n",
"                # scrambled prior - re-validate or retrain them.\n",
"                A = A.permute(0, 3, 1, 2).reshape(-1, A.size(1), A.size(2))\n",
"\n",
"        layer_weights = []\n",
"        # Iterate through each layer of DGT\n",
"        for i in range(self.num_layers):\n",
"            # First apply temporal attention to learn temporal dependencies\n",
"            temporal_out = self.temporal_attns[i](x, attn_mask=temporal_causal_mask, need_weights=need_weights)\n",
"            temporal_weights = None\n",
"            if need_weights:\n",
"                # The layer returns (output, weights); only the output joins the residual\n",
"                temporal_out, temporal_weights = temporal_out\n",
"            x = temporal_out + x\n",
"\n",
"            # Next apply spatial attention (aka differential graph attention) to learn interstock relations\n",
"            spatial_weights = None\n",
"            if self.use_spatial:\n",
"                x = x.permute(1, 0, 2)  # T, N, C\n",
"                spatial_out = self.spatial_attns[i](x, A, need_weights=need_weights)\n",
"                if need_weights:\n",
"                    spatial_out, spatial_weights = spatial_out\n",
"                x = spatial_out + x\n",
"                x = x.permute(1, 0, 2)  # N, T, C\n",
"\n",
"            if need_weights:\n",
"                layer_weights.append({'temporal': temporal_weights, 'spatial': spatial_weights})\n",
"\n",
"        if need_weights:\n",
"            return x, layer_weights\n",
"        return x"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A5zY_vud0dFI"
},
"source": [
"# GRU Baseline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y9rQkiLU0dFI"
},
"outputs": [],
"source": [
"# Gated Recurrent Unit baseline for comparison\n",
"class GRU(torch.nn.Module):\n",
" def __init__(self, in_channels: int, out_channels: int, num_layers: int):\n",
" super(GRU, self).__init__()\n",
" self.rnn = nn.GRU(input_size=in_channels, hidden_size=out_channels, num_layers=num_layers, batch_first=True)\n",
"\n",
" def forward(\n",
" self,\n",
" x: torch.FloatTensor,\n",
" edge_index: torch.LongTensor,\n",
" edge_weight: torch.FloatTensor = None,\n",
" ) -> torch.FloatTensor:\n",
" outputs, _ = self.rnn(x)\n",
" return outputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l__wiYS40dFI"
},
"source": [
"# Driver"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "shaDjKvu0dFI"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Common drive for all models\n",
"class Driver(torch.nn.Module):\n",
" def __init__(self, gnn, corr_name, corr_scope, node_features, hidden_size=32, **kwargs):\n",
" super(Driver, self).__init__()\n",
" self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)\n",
" self.linear = torch.nn.Linear(hidden_size, 1)\n",
" self.corr_name = corr_name\n",
" self.corr_scope = corr_scope\n",
"\n",
" # Run on the provided graph (specified with edge_index and edge_weight) and temporal signal x (past stock prices)\n",
" def forward(self, x, edge_index, edge_weight, hidden=None):\n",
" device = self.model_device()\n",
" if hidden is None:\n",
" outputs = self.recurrent(x.to(device), edge_index.to(device), edge_weight.to(device))\n",
" else:\n",
" outputs = self.recurrent(x.to(device), edge_index.to(device), edge_weight.to(device), hidden)\n",
" # Use final linear layer for regression\n",
" return self.linear(F.relu(outputs)), outputs\n",
"\n",
" # Get the model name for display and saving model weights\n",
" def model_name(self):\n",
" arch = self.model_arch()\n",
" if arch == 'GRU':\n",
" return f'{arch}'\n",
" elif arch == 'DGT':\n",
" name = f'{arch}{\"\" if self.recurrent.use_spatial else \"_no_spatial\"}'\n",
" if self.corr_scope is not None:\n",
" name += f'_{self.corr_name}_{self.corr_scope}'\n",
" return name\n",
"\n",
" # Get the model architecture for display\n",
" def model_arch(self):\n",
" return self.recurrent.__class__.__name__\n",
"\n",
" # Set the device for the model\n",
" def model_device(self):\n",
" return torch.device(\"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XgwYifus0dFI"
},
"source": [
"# Evaluation on Price Regression with RMSE and MAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yI5sg5TL0dFI"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Root Mean Squared Error for evaluation\n",
"def rmse(y_hat, y):\n",
" return math.sqrt(F.mse_loss(y_hat, y).item())\n",
"\n",
"# Mean Absolute Error for evaluation\n",
"def mae(y_hat, y):\n",
" return F.l1_loss(y_hat, y).item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L7FQLMzi0dFI"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"# Helper function for inference\n",
"def infer(model, snapshot):\n",
" X = snapshot.x\n",
" batch_y_hats, _ = model(X.transpose(0, 1), snapshot.edge_index, snapshot.edge_attr)\n",
" return batch_y_hats[:, -1]\n",
"\n",
"# Evaluate the model on eval_dataset and calculate RMSE and MAE\n",
"def eval(model, eval_dataset):\n",
"    \"\"\"Run model over every snapshot of eval_dataset and score the predictions.\n",
"\n",
"    Args:\n",
"        model: module exposing eval() and model_device(), callable through infer.\n",
"        eval_dataset: iterable of snapshots providing the inputs and a target y.\n",
"\n",
"    Returns:\n",
"        Dict with the stacked predictions ('y_hats'), the stacked targets ('ys') and\n",
"        the scalar metrics 'rmse' and 'mae'.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    device = model.model_device()\n",
"    predictions, targets = [], []\n",
"    with torch.no_grad():\n",
"        # Single pass, so one-shot iterables (generators, loaders) work as well\n",
"        for snapshot in eval_dataset:\n",
"            predictions.append(infer(model, snapshot))\n",
"            targets.append(snapshot.y)\n",
"        y_hats = torch.stack(predictions, dim=0).to(device)\n",
"        ys = torch.stack(targets, dim=0).to(device)\n",
"        if y_hats.numel() == ys.numel():\n",
"            # Align predictions with the targets explicitly. A bare squeeze() also drops\n",
"            # a dataset or node axis of size 1, after which the losses silently broadcast.\n",
"            y_hats = y_hats.reshape(ys.shape)\n",
"        else:\n",
"            y_hats = y_hats.squeeze()\n",
"        eval_rmse = rmse(y_hats, ys)\n",
"        eval_mae = mae(y_hats, ys)\n",
"    return {'y_hats': y_hats, 'ys': ys, 'rmse': eval_rmse, 'mae': eval_mae}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W_Q0NEde0dFI"
},
"outputs": [],
"source": [
"# Helper function to get a model based on the input configs, optionally loads the weight if load_weights=True\n",
"def get_model(gnn, use_spatial, corr_name, corr_scope, lr, load_weights=False):\n",
"    \"\"\"Build a Driver around gnn and move it to the preferred device.\n",
"\n",
"    Args:\n",
"        gnn: model class, either DGT or GRU.\n",
"        use_spatial: forwarded to DGT; ignored for GRU.\n",
"        corr_name: correlation name used for the graph and for the checkpoint name.\n",
"        corr_scope: correlation scope used for the graph and for the checkpoint name.\n",
"        lr: learning rate, only used to locate the checkpoint file.\n",
"        load_weights: when True, load the matching checkpoint from workdir.\n",
"\n",
"    Returns:\n",
"        The model, or None for a GRU combined with a correlation (unsupported).\n",
"\n",
"    Raises:\n",
"        ValueError: if gnn is neither DGT nor GRU.\n",
"    \"\"\"\n",
"    node_features = 1\n",
"    if gnn is DGT:\n",
"        model = Driver(gnn, corr_name, corr_scope, node_features, num_heads=2, use_spatial=use_spatial)\n",
"    elif gnn is GRU:\n",
"        # GRU does not support any correlation\n",
"        if corr_name is not None or corr_scope is not None:\n",
"            return None\n",
"        model = Driver(gnn, None, None, node_features, num_layers=2)\n",
"    else:\n",
"        # Previously this fell through to an UnboundLocalError on model\n",
"        raise ValueError(f'Unsupported model class: {gnn!r}')\n",
"    if load_weights:\n",
"        # Checkpoints saved on a GPU cannot be deserialized on a CPU-only machine unless\n",
"        # the tensors are remapped; load on CPU, then move the whole model below.\n",
"        state_dict = torch.load(\n",
"            f'{workdir}/models/{model.model_name()}_lr_{lr}.pth',\n",
"            map_location='cpu',\n",
"            weights_only=True,\n",
"        )\n",
"        model.load_state_dict(state_dict)\n",
"    return model.to(model.model_device())\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I2ICXl7m0dFI"
},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"source": [
"**Note: Training took 3 hours on a T4. You can skip the following code block and run the evaluations directly as our checkpoints are already downloaded.** In case you are training, you can also set `track_with_wandb` to `False` if you don't want to track with Weights and Biases."
],
"metadata": {
"id": "SOAe947ephXR"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "cYjWdVjJ0dFI",
"outputId": "dfa35f72-50ba-402d-c6df-0688f52a32ca"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"\n",
" window._wandbApiKey = new Promise((resolve, reject) => {\n",
" function loadScript(url) {\n",
" return new Promise(function(resolve, reject) {\n",
" let newScript = document.createElement(\"script\");\n",
" newScript.onerror = reject;\n",
" newScript.onload = resolve;\n",
" document.body.appendChild(newScript);\n",
" newScript.src = url;\n",
" });\n",
" }\n",
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
" const iframe = document.createElement('iframe')\n",
" iframe.style.cssText = \"width:0;height:0;border:none\"\n",
" document.body.appendChild(iframe)\n",
" const handshake = new Postmate({\n",
" container: iframe,\n",
" url: 'https://wandb.ai/authorize'\n",
" });\n",
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
" handshake.then(function(child) {\n",
" child.on('authorize', data => {\n",
" clearTimeout(timeout)\n",
" resolve(data)\n",
" });\n",
" });\n",
" })\n",
" });\n",
" "
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mkevinxli\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_200232-5lp1cu5y</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/5lp1cu5y' target=\"_blank\">GRU_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/5lp1cu5y' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/5lp1cu5y</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"GRU epoch 0 val/rmse: 1.0151258562542624 val/mae: 0.5934833884239197\n",
"GRU epoch 10 val/rmse: 0.7394664072011641 val/mae: 0.24550306797027588\n",
"GRU epoch 20 val/rmse: 0.7674437349427956 val/mae: 0.2417578399181366\n",
"GRU epoch 30 val/rmse: 0.6855584602789891 val/mae: 0.13805679976940155\n",
"GRU epoch 40 val/rmse: 0.6683109812157823 val/mae: 0.1362270712852478\n",
"GRU epoch 50 val/rmse: 0.6567212184121727 val/mae: 0.12936756014823914\n",
"GRU epoch 60 val/rmse: 0.6512660415562784 val/mae: 0.1538236290216446\n",
"GRU epoch 70 val/rmse: 0.6393448902919666 val/mae: 0.1312980055809021\n",
"GRU epoch 80 val/rmse: 0.6378874321394541 val/mae: 0.14947465062141418\n",
"GRU epoch 90 val/rmse: 0.6279125062956136 val/mae: 0.13595153391361237\n",
"GRU epoch 99 val/rmse: 0.6195823465564757 val/mae: 0.11029697954654694\n",
"GRU lr: 0.01 test/rmse: 3.1166079945745024 test/mae: 0.4076187312602997\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██</td></tr><tr><td>step</td><td>▁▃▁▅▂▃▃▃▃▂▇▇▆▃▂▅▆▃▇▆▁▁▅█▄▁▇█▃▇▃█▁▃▇▆▂▇▃█</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▄▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▃▃▁▁▁▂▁▂▁▁</td></tr><tr><td>val/rmse</td><td>█▃▄▂▂▂▂▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.40762</td></tr><tr><td>test/rmse</td><td>3.11661</td></tr><tr><td>train/loss</td><td>0.01211</td></tr><tr><td>val/best_mae</td><td>0.1103</td></tr><tr><td>val/best_rmse</td><td>0.61958</td></tr><tr><td>val/mae</td><td>0.1103</td></tr><tr><td>val/rmse</td><td>0.61958</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">GRU_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/5lp1cu5y' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/5lp1cu5y</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_200232-5lp1cu5y/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_200257-53m3csrj</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/53m3csrj' target=\"_blank\">GRU_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/53m3csrj' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/53m3csrj</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"GRU epoch 0 val/rmse: 1.147122415157888 val/mae: 0.7995038032531738\n",
"GRU epoch 10 val/rmse: 1.0865229942462133 val/mae: 0.5747919082641602\n",
"GRU epoch 20 val/rmse: 1.1592675603968394 val/mae: 0.7387993931770325\n",
"GRU epoch 30 val/rmse: 1.1101862130896667 val/mae: 0.624571681022644\n",
"GRU epoch 40 val/rmse: 1.0454624382105593 val/mae: 0.5586006045341492\n",
"GRU epoch 50 val/rmse: 0.9844211083164526 val/mae: 0.4778263568878174\n",
"GRU epoch 60 val/rmse: 1.0663733844094228 val/mae: 0.5754823684692383\n",
"GRU epoch 70 val/rmse: 1.0399111386996494 val/mae: 0.5042101740837097\n",
"GRU epoch 80 val/rmse: 1.014945169480599 val/mae: 0.48927828669548035\n",
"GRU epoch 90 val/rmse: 1.0538482717935054 val/mae: 0.5330270528793335\n",
"GRU epoch 99 val/rmse: 0.969964189409132 val/mae: 0.4710395634174347\n",
"GRU lr: 0.1 test/rmse: 3.527882413405376 test/mae: 1.0555206537246704\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇████</td></tr><tr><td>step</td><td>▇▅▂▄▁▄▇▂▇█▅▃▁▁▃▃▄▁▃▃▃▁▇▁▅▅█▅▅▆▁▇▆▇▂▅▅▂▄▅</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▆▂▂▅▅█▅▃▅▂▃▄▆▆▂▁▁▂▁▃▁▃▅▂▃▁▂▂▁▁▁▃▁▁▃▃▁▁▁▂</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▃▇▄▃▁▃▂▁▂▁</td></tr><tr><td>val/rmse</td><td>█▅█▆▄▂▅▄▃▄▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>1.05552</td></tr><tr><td>test/rmse</td><td>3.52788</td></tr><tr><td>train/loss</td><td>0.18069</td></tr><tr><td>val/best_mae</td><td>0.47104</td></tr><tr><td>val/best_rmse</td><td>0.96996</td></tr><tr><td>val/mae</td><td>0.47104</td></tr><tr><td>val/rmse</td><td>0.96996</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">GRU_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/53m3csrj' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/53m3csrj</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_200257-53m3csrj/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_200322-yd1xc1ur</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yd1xc1ur' target=\"_blank\">DGT_no_spatial_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yd1xc1ur' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yd1xc1ur</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_no_spatial epoch 0 val/rmse: 1.0516432732541026 val/mae: 0.6799300909042358\n",
"DGT_no_spatial epoch 10 val/rmse: 0.5059577816720363 val/mae: 0.17544981837272644\n",
"DGT_no_spatial epoch 20 val/rmse: 0.4525906931368216 val/mae: 0.1393440067768097\n",
"DGT_no_spatial epoch 30 val/rmse: 0.4252298266622955 val/mae: 0.12889644503593445\n",
"DGT_no_spatial epoch 40 val/rmse: 0.400435394017216 val/mae: 0.118231400847435\n",
"DGT_no_spatial epoch 50 val/rmse: 0.36498747859667224 val/mae: 0.09809175878763199\n",
"DGT_no_spatial epoch 60 val/rmse: 0.3594312209229434 val/mae: 0.10356435179710388\n",
"DGT_no_spatial epoch 70 val/rmse: 0.3336577240219205 val/mae: 0.11112850904464722\n",
"DGT_no_spatial epoch 80 val/rmse: 0.32225048145201995 val/mae: 0.08382288366556168\n",
"DGT_no_spatial epoch 90 val/rmse: 0.31846378229186056 val/mae: 0.08361205458641052\n",
"DGT_no_spatial epoch 99 val/rmse: 0.3056713383695309 val/mae: 0.07731342315673828\n",
"DGT_no_spatial lr: 0.01 test/rmse: 1.4920656643804655 test/mae: 0.1991715133190155\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇██████</td></tr><tr><td>step</td><td>▃▅▃▇▃▃▄█▁▃▆▆▄▃▃▃▅▅▄▆▄▆▅▃▅▇▃▆▇▂▄▅█▄▅▆▃▃▆▆</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▂▂▂▁▁▁▁▁▁▁</td></tr><tr><td>val/rmse</td><td>█▃▂▂▂▂▂▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.19917</td></tr><tr><td>test/rmse</td><td>1.49207</td></tr><tr><td>train/loss</td><td>0.01076</td></tr><tr><td>val/best_mae</td><td>0.07731</td></tr><tr><td>val/best_rmse</td><td>0.30567</td></tr><tr><td>val/mae</td><td>0.07731</td></tr><tr><td>val/rmse</td><td>0.30567</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_no_spatial_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yd1xc1ur' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yd1xc1ur</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_200322-yd1xc1ur/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_201021-13einxap</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/13einxap' target=\"_blank\">DGT_no_spatial_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/13einxap' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/13einxap</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_no_spatial epoch 0 val/rmse: 0.7455923182958474 val/mae: 0.5306437015533447\n",
"DGT_no_spatial epoch 10 val/rmse: 0.2376273673249165 val/mae: 0.1642322540283203\n",
"DGT_no_spatial epoch 20 val/rmse: 0.2251414837651074 val/mae: 0.15359073877334595\n",
"DGT_no_spatial epoch 30 val/rmse: 0.2587245212901415 val/mae: 0.1711944043636322\n",
"DGT_no_spatial epoch 40 val/rmse: 0.26055001023906893 val/mae: 0.17343495786190033\n",
"DGT_no_spatial epoch 50 val/rmse: 0.2182623673329115 val/mae: 0.13900776207447052\n",
"DGT_no_spatial epoch 60 val/rmse: 0.17033703281725954 val/mae: 0.09737835824489594\n",
"DGT_no_spatial epoch 70 val/rmse: 0.21870207261523666 val/mae: 0.1400461047887802\n",
"DGT_no_spatial epoch 80 val/rmse: 0.23271536528880676 val/mae: 0.1374342441558838\n",
"DGT_no_spatial epoch 90 val/rmse: 0.2509344043279461 val/mae: 0.1355830729007721\n",
"DGT_no_spatial epoch 99 val/rmse: 0.19533162976995996 val/mae: 0.12171226739883423\n",
"DGT_no_spatial lr: 0.1 test/rmse: 0.4328992281618065 test/mae: 0.18165773153305054\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇███</td></tr><tr><td>step</td><td>▃▄▄▇█▆▄▄▁▇▅▃▂▄▇▁▄▄▆▄▁▅▄▆█▂▄▅▅▇▆▃▄▅▅▅▆▆▃▄</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▆▅▂▂▂▁▁▁▂▂▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▂▂▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▂▂▂▂▂▁▂▂▂▁</td></tr><tr><td>val/rmse</td><td>█▂▂▂▂▂▁▂▂▂▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.18166</td></tr><tr><td>test/rmse</td><td>0.4329</td></tr><tr><td>train/loss</td><td>0.02425</td></tr><tr><td>val/best_mae</td><td>0.09738</td></tr><tr><td>val/best_rmse</td><td>0.17034</td></tr><tr><td>val/mae</td><td>0.12171</td></tr><tr><td>val/rmse</td><td>0.19533</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_no_spatial_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/13einxap' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/13einxap</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_201021-13einxap/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_201642-t8d39ll0</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/t8d39ll0' target=\"_blank\">DGT_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/t8d39ll0' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/t8d39ll0</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT epoch 0 val/rmse: 1.1663328215372715 val/mae: 0.866641104221344\n",
"DGT epoch 10 val/rmse: 0.5661705842233048 val/mae: 0.18316423892974854\n",
"DGT epoch 20 val/rmse: 0.4797872657038466 val/mae: 0.141910582780838\n",
"DGT epoch 30 val/rmse: 0.46390838440108234 val/mae: 0.15902476012706757\n",
"DGT epoch 40 val/rmse: 0.4005524843141996 val/mae: 0.12455122172832489\n",
"DGT epoch 50 val/rmse: 0.40952463109017445 val/mae: 0.10992292314767838\n",
"DGT epoch 60 val/rmse: 0.42137112373146624 val/mae: 0.13949576020240784\n",
"DGT epoch 70 val/rmse: 0.38506452694684307 val/mae: 0.21908588707447052\n",
"DGT epoch 80 val/rmse: 0.3309199979114579 val/mae: 0.11764570325613022\n",
"DGT epoch 90 val/rmse: 0.33019609995222 val/mae: 0.11783812195062637\n",
"DGT epoch 99 val/rmse: 0.3275464839014915 val/mae: 0.11910999566316605\n",
"DGT lr: 0.01 test/rmse: 1.473704888039203 test/mae: 0.2925158739089966\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▂▂▂▂▂▂▂▂▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇████</td></tr><tr><td>step</td><td>▇▂▃█▆▅▅▄▅▁▁█▇▅▁▅▇▃▄▃▇▁█▃▃▃▆▁▅▃▆▇▃▃▂▇▆▆▄▅</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▄▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▂▁▁▁▁▁▂▁▁▁</td></tr><tr><td>val/rmse</td><td>█▃▂▂▂▂▂▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.29252</td></tr><tr><td>test/rmse</td><td>1.4737</td></tr><tr><td>train/loss</td><td>0.01151</td></tr><tr><td>val/best_mae</td><td>0.11911</td></tr><tr><td>val/best_rmse</td><td>0.32755</td></tr><tr><td>val/mae</td><td>0.11911</td></tr><tr><td>val/rmse</td><td>0.32755</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/t8d39ll0' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/t8d39ll0</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_201642-t8d39ll0/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_202832-hzbqy1h9</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/hzbqy1h9' target=\"_blank\">DGT_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/hzbqy1h9' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/hzbqy1h9</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT epoch 0 val/rmse: 1.2528722189421675 val/mae: 1.1105965375900269\n",
"DGT epoch 10 val/rmse: 0.6066736844813188 val/mae: 0.146450012922287\n",
"DGT epoch 20 val/rmse: 0.6511204379429594 val/mae: 0.33825546503067017\n",
"DGT epoch 30 val/rmse: 0.544522374873031 val/mae: 0.2150142341852188\n",
"DGT epoch 40 val/rmse: 0.6574977185024812 val/mae: 0.3009336590766907\n",
"DGT epoch 50 val/rmse: 0.4153857768839304 val/mae: 0.1313445270061493\n",
"DGT epoch 60 val/rmse: 0.36706583564109163 val/mae: 0.13257576525211334\n",
"DGT epoch 70 val/rmse: 0.33499947151099696 val/mae: 0.1202550008893013\n",
"DGT epoch 80 val/rmse: 0.3908646420679214 val/mae: 0.1295650452375412\n",
"DGT epoch 90 val/rmse: 0.37920721626738946 val/mae: 0.18798206746578217\n",
"DGT epoch 99 val/rmse: 0.41624913355877485 val/mae: 0.2747070789337158\n",
"DGT lr: 0.1 test/rmse: 1.628739163178983 test/mae: 0.23096409440040588\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█</td></tr><tr><td>step</td><td>▅▁▄▅▄▁▃▁▅█▁▇▃▃▇▆▇▂▁▁▆▆▃▄▂▁▃▂▃▃▅▅▆▇█▁▁█▄▂</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▃▂▃▃▃█▂▂▂▄▂▃▅▄▂▁▁▁▁▃▁▃▃▂▁▃▂▆▁▃▃▁▂█▂▂▃▆▁▃</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▁▃▂▂▁▁▁▁▁▂</td></tr><tr><td>val/rmse</td><td>█▃▃▃▃▂▁▁▁▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.23096</td></tr><tr><td>test/rmse</td><td>1.62874</td></tr><tr><td>train/loss</td><td>0.08559</td></tr><tr><td>val/best_mae</td><td>0.12026</td></tr><tr><td>val/best_rmse</td><td>0.335</td></tr><tr><td>val/mae</td><td>0.27471</td></tr><tr><td>val/rmse</td><td>0.41625</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/hzbqy1h9' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/hzbqy1h9</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_202832-hzbqy1h9/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_204024-epf8kmbr</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/epf8kmbr' target=\"_blank\">DGT_mi_global_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/epf8kmbr' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/epf8kmbr</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_global epoch 0 val/rmse: 1.0012687740399335 val/mae: 0.6714795827865601\n",
"DGT_mi_global epoch 10 val/rmse: 0.5359844843341077 val/mae: 0.31612247228622437\n",
"DGT_mi_global epoch 20 val/rmse: 0.502615406391655 val/mae: 0.2863154113292694\n",
"DGT_mi_global epoch 30 val/rmse: 0.3854373161695279 val/mae: 0.23913811147212982\n",
"DGT_mi_global epoch 40 val/rmse: 0.19174504949229512 val/mae: 0.11173715442419052\n",
"DGT_mi_global epoch 50 val/rmse: 0.19272254341589656 val/mae: 0.12769179046154022\n",
"DGT_mi_global epoch 60 val/rmse: 0.17119214568601335 val/mae: 0.1005631685256958\n",
"DGT_mi_global epoch 70 val/rmse: 0.15980018018186004 val/mae: 0.10093577951192856\n",
"DGT_mi_global epoch 80 val/rmse: 0.14929510504782809 val/mae: 0.08889802545309067\n",
"DGT_mi_global epoch 90 val/rmse: 0.14006206637412835 val/mae: 0.08726727962493896\n",
"DGT_mi_global epoch 99 val/rmse: 0.19063454393040083 val/mae: 0.11815609037876129\n",
"DGT_mi_global lr: 0.01 test/rmse: 0.46510339746130386 test/mae: 0.11491794139146805\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇█████</td></tr><tr><td>step</td><td>▅▇▁▆▅▅█▅█▇▁▇▂▃▅▂▃▆▇▄▃▅▅▄▃▇▁▅▇▅▅▃▂▄▇▃▁▃█▇</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▆▃▂▂▄▂▂▂▂▂▁▁▁▁▁▁▁▂▁▁▁▂▂▁▂▁▁▁▁▁▁▂▁▂▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▄▃▃▁▁▁▁▁▁▁</td></tr><tr><td>val/rmse</td><td>█▄▄▃▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.11492</td></tr><tr><td>test/rmse</td><td>0.4651</td></tr><tr><td>train/loss</td><td>0.01955</td></tr><tr><td>val/best_mae</td><td>0.08727</td></tr><tr><td>val/best_rmse</td><td>0.14006</td></tr><tr><td>val/mae</td><td>0.11816</td></tr><tr><td>val/rmse</td><td>0.19063</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_global_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/epf8kmbr' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/epf8kmbr</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_204024-epf8kmbr/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_205235-kok5k8bg</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kok5k8bg' target=\"_blank\">DGT_mi_global_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kok5k8bg' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kok5k8bg</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_global epoch 0 val/rmse: 0.8252649402022962 val/mae: 0.6091215014457703\n",
"DGT_mi_global epoch 10 val/rmse: 0.628280778802649 val/mae: 0.20826566219329834\n",
"DGT_mi_global epoch 20 val/rmse: 0.4549073647948406 val/mae: 0.13138806819915771\n",
"DGT_mi_global epoch 30 val/rmse: 0.43879581035654247 val/mae: 0.20296455919742584\n",
"DGT_mi_global epoch 40 val/rmse: 0.3909421777483479 val/mae: 0.0946093201637268\n",
"DGT_mi_global epoch 50 val/rmse: 0.37161972516393743 val/mae: 0.10006465017795563\n",
"DGT_mi_global epoch 60 val/rmse: 0.5672059300478544 val/mae: 0.4679655432701111\n",
"DGT_mi_global epoch 70 val/rmse: 0.6988751083794231 val/mae: 0.4168681204319\n",
"DGT_mi_global epoch 80 val/rmse: 0.6391157421886711 val/mae: 0.21604809165000916\n",
"DGT_mi_global epoch 90 val/rmse: 0.6124751874218606 val/mae: 0.2516571283340454\n",
"DGT_mi_global epoch 99 val/rmse: 0.8388722322998409 val/mae: 0.17682863771915436\n",
"DGT_mi_global lr: 0.1 test/rmse: 1.6498760436920556 test/mae: 0.1788974553346634\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇████</td></tr><tr><td>step</td><td>▂▁▅▁▃▅▆▇▃▅▅▅▅▇█▂▁▃▂▅█▄▃▇█▆▂▂▆▃▄▅▆▇▁▅▃▁▂▅</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▃█▇▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▃▂▂▁▁▆▅▃▃▂</td></tr><tr><td>val/rmse</td><td>█▅▂▂▁▁▄▆▅▅█</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.1789</td></tr><tr><td>test/rmse</td><td>1.64988</td></tr><tr><td>train/loss</td><td>0.03201</td></tr><tr><td>val/best_mae</td><td>0.10006</td></tr><tr><td>val/best_rmse</td><td>0.37162</td></tr><tr><td>val/mae</td><td>0.17683</td></tr><tr><td>val/rmse</td><td>0.83887</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_global_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kok5k8bg' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kok5k8bg</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_205235-kok5k8bg/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_210454-kctbq2yn</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kctbq2yn' target=\"_blank\">DGT_mi_local_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kctbq2yn' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kctbq2yn</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_local epoch 0 val/rmse: 0.9002928111801067 val/mae: 0.5393624305725098\n",
"DGT_mi_local epoch 10 val/rmse: 0.6887280160542051 val/mae: 0.4005378186702728\n",
"DGT_mi_local epoch 20 val/rmse: 0.4494036915762458 val/mae: 0.2704002857208252\n",
"DGT_mi_local epoch 30 val/rmse: 0.2686673014951199 val/mae: 0.14062510430812836\n",
"DGT_mi_local epoch 40 val/rmse: 0.2844522879445983 val/mae: 0.20511041581630707\n",
"DGT_mi_local epoch 50 val/rmse: 0.16961481588168448 val/mae: 0.11053712666034698\n",
"DGT_mi_local epoch 60 val/rmse: 0.17120601223983523 val/mae: 0.10645116120576859\n",
"DGT_mi_local epoch 70 val/rmse: 0.18160796128140974 val/mae: 0.1133645623922348\n",
"DGT_mi_local epoch 80 val/rmse: 0.1421013135967102 val/mae: 0.09032121300697327\n",
"DGT_mi_local epoch 90 val/rmse: 0.15679185611236277 val/mae: 0.09798093140125275\n",
"DGT_mi_local epoch 99 val/rmse: 0.14165065177569927 val/mae: 0.10277072340250015\n",
"DGT_mi_local lr: 0.01 test/rmse: 0.3241913680497987 test/mae: 0.13397999107837677\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇██</td></tr><tr><td>step</td><td>▂▁▇▇▁▅▅█▇▃▃▂▃▂▁▇▂▅▃▃▁▃▃▅▅▅▅▇▇▇▅▂▁▆▅▁▃▇▇▄</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▂▁▂▂▂▂▃▂▂▃▂▄▂▂▃▁▂▇▅▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▆▄▂▃▁▁▁▁▁▁</td></tr><tr><td>val/rmse</td><td>█▆▄▂▂▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.13398</td></tr><tr><td>test/rmse</td><td>0.32419</td></tr><tr><td>train/loss</td><td>0.01362</td></tr><tr><td>val/best_mae</td><td>0.10277</td></tr><tr><td>val/best_rmse</td><td>0.14165</td></tr><tr><td>val/mae</td><td>0.10277</td></tr><tr><td>val/rmse</td><td>0.14165</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_local_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kctbq2yn' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/kctbq2yn</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_210454-kctbq2yn/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_211716-cxxigb15</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/cxxigb15' target=\"_blank\">DGT_mi_local_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/cxxigb15' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/cxxigb15</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_local epoch 0 val/rmse: 1.037091673504184 val/mae: 0.5196841359138489\n",
"DGT_mi_local epoch 10 val/rmse: 1.060045773759421 val/mae: 0.7239880561828613\n",
"DGT_mi_local epoch 20 val/rmse: 0.9378771341252035 val/mae: 0.2847558259963989\n",
"DGT_mi_local epoch 30 val/rmse: 0.7247933619443653 val/mae: 0.2686854898929596\n",
"DGT_mi_local epoch 40 val/rmse: 1.01728409456467 val/mae: 0.28426408767700195\n",
"DGT_mi_local epoch 50 val/rmse: 0.9637774857635276 val/mae: 0.25843414664268494\n",
"DGT_mi_local epoch 60 val/rmse: 0.7276040272222197 val/mae: 0.1767144352197647\n",
"DGT_mi_local epoch 70 val/rmse: 1.1755456022944113 val/mae: 0.3743884861469269\n",
"DGT_mi_local epoch 80 val/rmse: 1.273515897806638 val/mae: 0.3659111261367798\n",
"DGT_mi_local epoch 90 val/rmse: 0.8872466343557628 val/mae: 0.31716471910476685\n",
"DGT_mi_local epoch 99 val/rmse: 1.2860957830452633 val/mae: 0.24827918410301208\n",
"DGT_mi_local lr: 0.1 test/rmse: 3.4772281995201415 test/mae: 0.6231410503387451\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇████</td></tr><tr><td>step</td><td>▇▁▇▃▁▁▇▆▄▃▇▄▃▇▁▂▁▂▂▅▇▅▅▇▁▇▇▁▃▅▇▅▁▁█▁▇█▅█</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▅▂▁▁▁▁▁▁▁▃▃▁▁▂▂▁▁▁▁▁▂▂▃▁▁▂▂▂▁▂▂▄▁▂▃▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>▅█▂▂▂▂▁▄▃▃▂</td></tr><tr><td>val/rmse</td><td>▅▅▄▁▅▄▁▇█▃█</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.62314</td></tr><tr><td>test/rmse</td><td>3.47723</td></tr><tr><td>train/loss</td><td>0.03594</td></tr><tr><td>val/best_mae</td><td>0.26869</td></tr><tr><td>val/best_rmse</td><td>0.72479</td></tr><tr><td>val/mae</td><td>0.24828</td></tr><tr><td>val/rmse</td><td>1.2861</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_local_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/cxxigb15' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/cxxigb15</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_211716-cxxigb15/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_212916-ua5u9agb</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/ua5u9agb' target=\"_blank\">DGT_mi_dual_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/ua5u9agb' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/ua5u9agb</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_dual epoch 0 val/rmse: 0.7391773013740808 val/mae: 0.45560505986213684\n",
"DGT_mi_dual epoch 10 val/rmse: 0.5032932189825671 val/mae: 0.3045569658279419\n",
"DGT_mi_dual epoch 20 val/rmse: 0.39251623092792287 val/mae: 0.2444455772638321\n",
"DGT_mi_dual epoch 30 val/rmse: 0.37159131473368523 val/mae: 0.22902081906795502\n",
"DGT_mi_dual epoch 40 val/rmse: 0.2654297756394218 val/mae: 0.16484399139881134\n",
"DGT_mi_dual epoch 50 val/rmse: 0.25267701065461456 val/mae: 0.15037177503108978\n",
"DGT_mi_dual epoch 60 val/rmse: 0.2815375579938595 val/mae: 0.17735306918621063\n",
"DGT_mi_dual epoch 70 val/rmse: 0.1762174650948065 val/mae: 0.11425253003835678\n",
"DGT_mi_dual epoch 80 val/rmse: 0.16801417091504559 val/mae: 0.1077650710940361\n",
"DGT_mi_dual epoch 90 val/rmse: 0.20235126105842394 val/mae: 0.168687105178833\n",
"DGT_mi_dual epoch 99 val/rmse: 0.1216295180214235 val/mae: 0.0851675346493721\n",
"DGT_mi_dual lr: 0.01 test/rmse: 0.2598891143975098 test/mae: 0.09928729385137558\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇████</td></tr><tr><td>step</td><td>▃▃▂▄▅▅▆▇▅▂▅▇▂▆▁▅█▇▃▄▅▅█▆▇▃▅▂▃▆▃▇▆▇▃▅▃▃█▅</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▄█▃▂▁▃▅▃▃▃▁▁▁▁▁▃▂▁▁▄▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▄▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▅▄▄▃▂▃▂▁▃▁</td></tr><tr><td>val/rmse</td><td>█▅▄▄▃▂▃▂▂▂▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.09929</td></tr><tr><td>test/rmse</td><td>0.25989</td></tr><tr><td>train/loss</td><td>0.0151</td></tr><tr><td>val/best_mae</td><td>0.08517</td></tr><tr><td>val/best_rmse</td><td>0.12163</td></tr><tr><td>val/mae</td><td>0.08517</td></tr><tr><td>val/rmse</td><td>0.12163</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_dual_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/ua5u9agb' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/ua5u9agb</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_212916-ua5u9agb/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_214329-yctt35b7</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yctt35b7' target=\"_blank\">DGT_mi_dual_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yctt35b7' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yctt35b7</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_mi_dual epoch 0 val/rmse: 0.6942267996085015 val/mae: 0.43968915939331055\n",
"DGT_mi_dual epoch 10 val/rmse: 1.513055097374729 val/mae: 1.2570449113845825\n",
"DGT_mi_dual epoch 20 val/rmse: 0.21284556130644008 val/mae: 0.12287840992212296\n",
"DGT_mi_dual epoch 30 val/rmse: 0.29924503240868056 val/mae: 0.218398317694664\n",
"DGT_mi_dual epoch 40 val/rmse: 0.25668211417240455 val/mae: 0.18500438332557678\n",
"DGT_mi_dual epoch 50 val/rmse: 0.31338376011452096 val/mae: 0.2572592794895172\n",
"DGT_mi_dual epoch 60 val/rmse: 0.29370700643298053 val/mae: 0.24092897772789001\n",
"DGT_mi_dual epoch 70 val/rmse: 0.2481150796189882 val/mae: 0.1967838853597641\n",
"DGT_mi_dual epoch 80 val/rmse: 0.24141548534122279 val/mae: 0.17510837316513062\n",
"DGT_mi_dual epoch 90 val/rmse: 0.2549060211723098 val/mae: 0.1469665765762329\n",
"DGT_mi_dual epoch 99 val/rmse: 0.18624654659487494 val/mae: 0.13340604305267334\n",
"DGT_mi_dual lr: 0.1 test/rmse: 0.6633853826118248 test/mae: 0.21146634221076965\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇████</td></tr><tr><td>step</td><td>▇█▅▃▃▁▃▃▃██▃▃▁▁▁▃▅▇▃▅▃▁▇▃▅▅█▃▅▄▃▅█▁▇▃██▃</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▆▃▃██▂▂▁▂▁▁▁▁▁▁▂▁▁▁▂▆▂▁▂▁▂▂▂▁▂▁▁▃▂▁▂▂▂▁▂</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>▃█▁▂▁▂▂▁▁▁▁</td></tr><tr><td>val/rmse</td><td>▄█▁▂▁▂▂▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.21147</td></tr><tr><td>test/rmse</td><td>0.66339</td></tr><tr><td>train/loss</td><td>0.02027</td></tr><tr><td>val/best_mae</td><td>0.13341</td></tr><tr><td>val/best_rmse</td><td>0.18625</td></tr><tr><td>val/mae</td><td>0.13341</td></tr><tr><td>val/rmse</td><td>0.18625</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_mi_dual_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yctt35b7' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/yctt35b7</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_214329-yctt35b7/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_215741-8x4pwb05</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/8x4pwb05' target=\"_blank\">DGT_pcc_global_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/8x4pwb05' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/8x4pwb05</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_global epoch 0 val/rmse: 1.066074697074617 val/mae: 0.777685821056366\n",
"DGT_pcc_global epoch 10 val/rmse: 0.577958830588208 val/mae: 0.35283103585243225\n",
"DGT_pcc_global epoch 20 val/rmse: 0.3581205532869752 val/mae: 0.18000569939613342\n",
"DGT_pcc_global epoch 30 val/rmse: 0.28523552458925105 val/mae: 0.15971672534942627\n",
"DGT_pcc_global epoch 40 val/rmse: 0.19334186127848263 val/mae: 0.13116315007209778\n",
"DGT_pcc_global epoch 50 val/rmse: 0.23222657822078718 val/mae: 0.10899177938699722\n",
"DGT_pcc_global epoch 60 val/rmse: 0.23714674886052958 val/mae: 0.1874692142009735\n",
"DGT_pcc_global epoch 70 val/rmse: 0.1837740175451858 val/mae: 0.09711305797100067\n",
"DGT_pcc_global epoch 80 val/rmse: 0.18948618610215193 val/mae: 0.12731781601905823\n",
"DGT_pcc_global epoch 90 val/rmse: 0.1741502755572811 val/mae: 0.0949571430683136\n",
"DGT_pcc_global epoch 99 val/rmse: 0.27686721534635217 val/mae: 0.14365161955356598\n",
"DGT_pcc_global lr: 0.01 test/rmse: 0.6430229572051165 test/mae: 0.14795534312725067\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█</td></tr><tr><td>step</td><td>▃▃█▄▁▆▃▅▃▅▅▃▅▇▇▃▄▁▇▅▇▆█▅▇▅▄▆▁▂▅▇▃▇▆▂█▅▃▅</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▄▁▂▂▂▂▂▁▁█▄▂▂▁▁▁▁▁▁▁▁▁▁▅▃▁▁▁▁▁▁▃▁▂▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▄▂▂▁▁▂▁▁▁▁</td></tr><tr><td>val/rmse</td><td>█▄▂▂▁▁▁▁▁▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.14796</td></tr><tr><td>test/rmse</td><td>0.64302</td></tr><tr><td>train/loss</td><td>0.03702</td></tr><tr><td>val/best_mae</td><td>0.09496</td></tr><tr><td>val/best_rmse</td><td>0.17415</td></tr><tr><td>val/mae</td><td>0.14365</td></tr><tr><td>val/rmse</td><td>0.27687</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_global_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/8x4pwb05' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/8x4pwb05</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_215741-8x4pwb05/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_220906-bpsv7yp4</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/bpsv7yp4' target=\"_blank\">DGT_pcc_global_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/bpsv7yp4' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/bpsv7yp4</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_global epoch 0 val/rmse: 0.8879743584397527 val/mae: 0.5402513146400452\n",
"DGT_pcc_global epoch 10 val/rmse: 1.8407399450916289 val/mae: 0.2869521677494049\n",
"DGT_pcc_global epoch 20 val/rmse: 1.6265783714084423 val/mae: 0.20345649123191833\n",
"DGT_pcc_global epoch 30 val/rmse: 1.4992064125145774 val/mae: 0.262802392244339\n",
"DGT_pcc_global epoch 40 val/rmse: 1.3776486635383223 val/mae: 0.39483073353767395\n",
"DGT_pcc_global epoch 50 val/rmse: 1.6133821100291263 val/mae: 0.1863420605659485\n",
"DGT_pcc_global epoch 60 val/rmse: 1.7617123216735893 val/mae: 0.224675253033638\n",
"DGT_pcc_global epoch 70 val/rmse: 1.9495031726475758 val/mae: 0.42596253752708435\n",
"DGT_pcc_global epoch 80 val/rmse: 1.72775696738175 val/mae: 0.18614937365055084\n",
"DGT_pcc_global epoch 90 val/rmse: 1.7126741696633032 val/mae: 0.20532815158367157\n",
"DGT_pcc_global epoch 99 val/rmse: 1.6333274211906297 val/mae: 0.270331472158432\n",
"DGT_pcc_global lr: 0.1 test/rmse: 3.1046794152118564 test/mae: 0.8780212998390198\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇█</td></tr><tr><td>step</td><td>▇█▃▁█▅█▃▃▁▃▁▃▃▅▁▅▂▆▇▇▄▅▅▅▃▃▁▇▁▃█▃▂▅▁▃▁█▇</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▅█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▃▁▃▅▁▂▆▁▁▃</td></tr><tr><td>val/rmse</td><td>▁▇▆▅▄▆▇█▇▆▆</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.87802</td></tr><tr><td>test/rmse</td><td>3.10468</td></tr><tr><td>train/loss</td><td>0.05394</td></tr><tr><td>val/best_mae</td><td>0.54025</td></tr><tr><td>val/best_rmse</td><td>0.88797</td></tr><tr><td>val/mae</td><td>0.27033</td></tr><tr><td>val/rmse</td><td>1.63333</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_global_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/bpsv7yp4' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/bpsv7yp4</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_220906-bpsv7yp4/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_222056-uzdblt4s</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/uzdblt4s' target=\"_blank\">DGT_pcc_local_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/uzdblt4s' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/uzdblt4s</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_local epoch 0 val/rmse: 0.8985436418280034 val/mae: 0.550024151802063\n",
"DGT_pcc_local epoch 10 val/rmse: 0.6056835654869687 val/mae: 0.4801201820373535\n",
"DGT_pcc_local epoch 20 val/rmse: 0.3533077155504064 val/mae: 0.2916525900363922\n",
"DGT_pcc_local epoch 30 val/rmse: 0.15869354246609046 val/mae: 0.10822608321905136\n",
"DGT_pcc_local epoch 40 val/rmse: 0.1591691908856734 val/mae: 0.11498145759105682\n",
"DGT_pcc_local epoch 50 val/rmse: 0.13670301346813016 val/mae: 0.08955355733633041\n",
"DGT_pcc_local epoch 60 val/rmse: 0.15843628549956776 val/mae: 0.09703715145587921\n",
"DGT_pcc_local epoch 70 val/rmse: 0.12820453820853162 val/mae: 0.08339686691761017\n",
"DGT_pcc_local epoch 80 val/rmse: 0.24204782334404729 val/mae: 0.20554694533348083\n",
"DGT_pcc_local epoch 90 val/rmse: 0.17953268928714503 val/mae: 0.12237368524074554\n",
"DGT_pcc_local epoch 99 val/rmse: 0.11312804879247905 val/mae: 0.07187668234109879\n",
"DGT_pcc_local lr: 0.01 test/rmse: 0.2941033328303634 test/mae: 0.08675127476453781\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇█████</td></tr><tr><td>step</td><td>▃▃█▅▂▆▃▃▁▅▅▃▇▃█▁▃▁▅▅▅█▅█▅▁▅▅▅▇▂▅▇▁▄▇▇▃▃▁</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▇▄▄▆▅▂▅▂▃▂▁▂▂▁▂▂▂▂▂▁▂▁▁▁▁▂▁▂▁▁▂▂▃▂▁▁▃▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▇▄▂▂▁▁▁▃▂▁</td></tr><tr><td>val/rmse</td><td>█▅▃▁▁▁▁▁▂▂▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.08675</td></tr><tr><td>test/rmse</td><td>0.2941</td></tr><tr><td>train/loss</td><td>0.01151</td></tr><tr><td>val/best_mae</td><td>0.07188</td></tr><tr><td>val/best_rmse</td><td>0.11313</td></tr><tr><td>val/mae</td><td>0.07188</td></tr><tr><td>val/rmse</td><td>0.11313</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_local_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/uzdblt4s' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/uzdblt4s</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_222056-uzdblt4s/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_223303-f3qg6gdm</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/f3qg6gdm' target=\"_blank\">DGT_pcc_local_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/f3qg6gdm' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/f3qg6gdm</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_local epoch 0 val/rmse: 0.6162145164131906 val/mae: 0.46609604358673096\n",
"DGT_pcc_local epoch 10 val/rmse: 0.7962484327387693 val/mae: 0.25552991032600403\n",
"DGT_pcc_local epoch 20 val/rmse: 0.761364233218597 val/mae: 0.2903038263320923\n",
"DGT_pcc_local epoch 30 val/rmse: 0.6128863138177298 val/mae: 0.22324974834918976\n",
"DGT_pcc_local epoch 40 val/rmse: 0.620641483729527 val/mae: 0.20450033247470856\n",
"DGT_pcc_local epoch 50 val/rmse: 0.656507418871679 val/mae: 0.1447349190711975\n",
"DGT_pcc_local epoch 60 val/rmse: 0.762964804266313 val/mae: 0.1745220124721527\n",
"DGT_pcc_local epoch 70 val/rmse: 0.7329783042621677 val/mae: 0.16009120643138885\n",
"DGT_pcc_local epoch 80 val/rmse: 0.7164827744808504 val/mae: 0.1376502513885498\n",
"DGT_pcc_local epoch 90 val/rmse: 0.7559001428615346 val/mae: 0.20387840270996094\n",
"DGT_pcc_local epoch 99 val/rmse: 0.7303876780843128 val/mae: 0.18701553344726562\n",
"DGT_pcc_local lr: 0.1 test/rmse: 2.506897376109701 test/mae: 0.5010432004928589\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇███</td></tr><tr><td>step</td><td>▅▇▅█▇█▁█▃▇▂▃▇▃▃█▇▃▂▇▇▁▅▁▃▃▆▁▄█▃▃▄▂▅▃▁▇█▃</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▄▄▃▂▁▂▁▁▂▂</td></tr><tr><td>val/rmse</td><td>▁█▇▁▁▃▇▆▅▆▅</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.50104</td></tr><tr><td>test/rmse</td><td>2.5069</td></tr><tr><td>train/loss</td><td>0.01632</td></tr><tr><td>val/best_mae</td><td>0.22325</td></tr><tr><td>val/best_rmse</td><td>0.61289</td></tr><tr><td>val/mae</td><td>0.18702</td></tr><tr><td>val/rmse</td><td>0.73039</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_local_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/f3qg6gdm' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/f3qg6gdm</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_223303-f3qg6gdm/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_224536-pioxs823</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/pioxs823' target=\"_blank\">DGT_pcc_dual_lr_0.01</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/pioxs823' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/pioxs823</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_dual epoch 0 val/rmse: 1.0027066199069214 val/mae: 0.531448245048523\n",
"DGT_pcc_dual epoch 10 val/rmse: 0.5686480912932805 val/mae: 0.30941909551620483\n",
"DGT_pcc_dual epoch 20 val/rmse: 0.3920950693278707 val/mae: 0.14852195978164673\n",
"DGT_pcc_dual epoch 30 val/rmse: 0.42255946547839507 val/mae: 0.16893088817596436\n",
"DGT_pcc_dual epoch 40 val/rmse: 0.36900113008165863 val/mae: 0.13784931600093842\n",
"DGT_pcc_dual epoch 50 val/rmse: 0.39474134233486063 val/mae: 0.17326299846172333\n",
"DGT_pcc_dual epoch 60 val/rmse: 0.34443931390973087 val/mae: 0.1496657431125641\n",
"DGT_pcc_dual epoch 70 val/rmse: 0.4169347218709851 val/mae: 0.26017606258392334\n",
"DGT_pcc_dual epoch 80 val/rmse: 0.16448492209415183 val/mae: 0.12450996786355972\n",
"DGT_pcc_dual epoch 90 val/rmse: 0.205356155120654 val/mae: 0.12474433332681656\n",
"DGT_pcc_dual epoch 99 val/rmse: 0.1836568034230089 val/mae: 0.10215266048908234\n",
"DGT_pcc_dual lr: 0.01 test/rmse: 0.29678211515858255 test/mae: 0.15739485621452332\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██</td></tr><tr><td>step</td><td>▁▇▄▇▃▃▇▇▂▁▅▃▅▃▁▁▃▇▇█▁▃▅▅▄▃▃▁▅▅▁▇▇█▃▇▁▁▂▃</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>█▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▁▂▁▁▁▁▁</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>█▄▂▂▂▂▂▄▁▁▁</td></tr><tr><td>val/rmse</td><td>█▄▃▃▃▃▃▃▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.15739</td></tr><tr><td>test/rmse</td><td>0.29678</td></tr><tr><td>train/loss</td><td>0.01558</td></tr><tr><td>val/best_mae</td><td>0.12451</td></tr><tr><td>val/best_rmse</td><td>0.16448</td></tr><tr><td>val/mae</td><td>0.10215</td></tr><tr><td>val/rmse</td><td>0.18366</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_dual_lr_0.01</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/pioxs823' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/pioxs823</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_224536-pioxs823/logs</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.19.2"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20250119_225959-6w2qd34h</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/6w2qd34h' target=\"_blank\">DGT_pcc_dual_lr_0.1</a></strong> to <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/6w2qd34h' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/6w2qd34h</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGT_pcc_dual epoch 0 val/rmse: 0.5137868252199914 val/mae: 0.3425385653972626\n",
"DGT_pcc_dual epoch 10 val/rmse: 0.6532990547679005 val/mae: 0.16436786949634552\n",
"DGT_pcc_dual epoch 20 val/rmse: 0.5517294817377448 val/mae: 0.1413598507642746\n",
"DGT_pcc_dual epoch 30 val/rmse: 0.5629765558138379 val/mae: 0.13847942650318146\n",
"DGT_pcc_dual epoch 40 val/rmse: 0.7624036790638924 val/mae: 0.2591487765312195\n",
"DGT_pcc_dual epoch 50 val/rmse: 0.7292081594014849 val/mae: 0.3468876779079437\n",
"DGT_pcc_dual epoch 60 val/rmse: 0.8175930905783602 val/mae: 0.6307493448257446\n",
"DGT_pcc_dual epoch 70 val/rmse: 0.9565153327287272 val/mae: 0.2979082763195038\n",
"DGT_pcc_dual epoch 80 val/rmse: 0.7688234433703265 val/mae: 0.20415645837783813\n",
"DGT_pcc_dual epoch 90 val/rmse: 0.6282031708086251 val/mae: 0.21568505465984344\n",
"DGT_pcc_dual epoch 99 val/rmse: 0.6459992769815988 val/mae: 0.15335515141487122\n",
"DGT_pcc_dual lr: 0.1 test/rmse: 1.0980833044153202 test/mae: 0.5658656358718872\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>step</td><td>█▅▆▅▄▅▇▇▃▃▃▁▃▃▃▂▇▅▃▇▃▅▆▂▅▅▂▅▄▃▄█▃▅▃▁▇▇▆█</td></tr><tr><td>test/mae</td><td>▁</td></tr><tr><td>test/rmse</td><td>▁</td></tr><tr><td>train/loss</td><td>▇▂▃▁▁▁▃▁▁▁▁▁▁▃▂▂▂▂▃▁▁▂▁▂▂▂▂█▃▁▃▂▃▁▁▃▂▂▂▂</td></tr><tr><td>val/best_mae</td><td>▁</td></tr><tr><td>val/best_rmse</td><td>▁</td></tr><tr><td>val/mae</td><td>▄▁▁▁▃▄█▃▂▂▁</td></tr><tr><td>val/rmse</td><td>▁▃▂▂▅▄▆█▅▃▃</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>step</td><td>14</td></tr><tr><td>test/mae</td><td>0.56587</td></tr><tr><td>test/rmse</td><td>1.09808</td></tr><tr><td>train/loss</td><td>0.03013</td></tr><tr><td>val/best_mae</td><td>0.34254</td></tr><tr><td>val/best_rmse</td><td>0.51379</td></tr><tr><td>val/mae</td><td>0.15336</td></tr><tr><td>val/rmse</td><td>0.646</td></tr></table><br/></div></div>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run <strong style=\"color:#cdcd00\">DGT_pcc_dual_lr_0.1</strong> at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/6w2qd34h' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction/runs/6w2qd34h</a><br> View project at: <a href='https://wandb.ai/kevinxli/cs224w-stock-market-prediction' target=\"_blank\">https://wandb.ai/kevinxli/cs224w-stock-market-prediction</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Find logs at: <code>./wandb/run-20250119_225959-6w2qd34h/logs</code>"
]
},
"metadata": {}
}
],
"source": [
"import wandb\n",
"import os\n",
"\n",
"# Train a model with the input configs on train_samples for num_epochs under the learning rate lr\n",
"# You can pass track_with_wandb=True to trace the entire training process with Weights and Biases\n",
"def train(gnn, use_spatial, corr_name, corr_scope, train_samples, val_samples, num_epochs, lr, track_with_wandb):\n",
"    \"\"\"Train one model variant and checkpoint the weights with the lowest validation RMSE.\n",
"\n",
"    Args:\n",
"        gnn, use_spatial, corr_name, corr_scope: Model configuration forwarded to `get_model`.\n",
"        train_samples: Iterable of snapshots exposing `x`, `edge_index`, `edge_attr` and `y`.\n",
"        val_samples: Samples handed to `eval` for the periodic validation pass.\n",
"        num_epochs: Number of passes over `train_samples`.\n",
"        lr: Learning rate for Adam; also embedded in the run name and checkpoint file name.\n",
"        track_with_wandb: When True, log the whole run to Weights & Biases.\n",
"\n",
"    Returns:\n",
"        The wandb run handle when `track_with_wandb` is True, otherwise None.\n",
"    \"\"\"\n",
"    os.makedirs(f'{workdir}/models', exist_ok=True)\n",
"\n",
"    model = get_model(gnn, use_spatial, corr_name, corr_scope, lr, load_weights=False)\n",
"\n",
"    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
"    best_val_rmse = float('inf')\n",
"    best_val_mae = float('inf')\n",
"    eval_per_epoch = 10\n",
"\n",
"    # Defined up front so the final `return` also works when tracking is disabled\n",
"    wandb_run = None\n",
"    if track_with_wandb:\n",
"        wandb_run = wandb.init(project=\"cs224w-stock-market-prediction\",\n",
"                               name=f'{model.model_name()}_lr_{lr}',\n",
"                               config={\n",
"                                   \"corr_name\": corr_name,\n",
"                                   \"corr_scope\": corr_scope,\n",
"                                   \"learning_rate\": lr,\n",
"                                   \"epochs\": num_epochs,\n",
"                                   \"architecture\": gnn.__name__,\n",
"                                   \"use_spatial\": use_spatial,\n",
"                               },\n",
"                               reinit=True,\n",
"                               )\n",
"\n",
"    for epoch in range(num_epochs):\n",
"        model.train()\n",
"        for step, snapshot in enumerate(train_samples):\n",
"            optimizer.zero_grad()\n",
"            X = snapshot.x\n",
"            y_hats, _ = model(X.transpose(0, 1), snapshot.edge_index, snapshot.edge_attr, hidden=None)\n",
"            loss = F.mse_loss(y_hats.squeeze(), snapshot.y.to(model.model_device()))\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            if track_with_wandb:\n",
"                wandb.log({\"epoch\": epoch, \"step\": step, \"train/loss\": loss.item() })\n",
"\n",
"        # Validate every `eval_per_epoch` epochs and always after the final epoch\n",
"        if epoch % eval_per_epoch == 0 or epoch == num_epochs - 1:\n",
"            result = eval(model, val_samples)\n",
"            val_rmse = result['rmse']\n",
"            val_mae = result['mae']\n",
"            print(f'{model.model_name()} epoch {epoch} val/rmse: {val_rmse} val/mae: {val_mae}')\n",
"            if track_with_wandb:\n",
"                wandb.log({\"epoch\": epoch, \"val/rmse\": val_rmse, \"val/mae\": val_mae })\n",
"            # Keep only the checkpoint with the lowest validation RMSE seen so far\n",
"            if val_rmse < best_val_rmse:\n",
"                best_val_rmse = val_rmse\n",
"                best_val_mae = val_mae\n",
"                torch.save(model.state_dict(), f'{workdir}/models/{model.model_name()}_lr_{lr}.pth')\n",
"    if track_with_wandb:\n",
"        wandb.log({\"val/best_rmse\": best_val_rmse, \"val/best_mae\": best_val_mae })\n",
"    return wandb_run\n",
"\n",
"\n",
"def run(args):\n",
"    \"\"\"Train and test one model configuration under every candidate learning rate.\n",
"\n",
"    `args` is a `(config, num_epochs, track_with_wandb)` tuple, where `config` is\n",
"    `(gnn, use_spatial, corr_name, corr_scope)`.\n",
"    \"\"\"\n",
"    (gnn, use_spatial, corr_name, corr_scope), num_epochs, track_with_wandb = args\n",
"    dataset = get_dataset(corr_name, corr_scope)\n",
"    # Models are sensitive to the learning rate, so each candidate value gets its own full training run.\n",
"    for lr in (0.01, 0.1):\n",
"        wandb_run = train(gnn, use_spatial, corr_name, corr_scope, dataset['train_samples'], dataset['val_samples'], num_epochs, lr, track_with_wandb)\n",
"        # Score the model obtained with load_weights=True on the test split\n",
"        best_model = get_model(gnn, use_spatial, corr_name, corr_scope, lr, load_weights=True)\n",
"        test_metrics = eval(best_model, dataset['test_samples'])\n",
"        test_rmse, test_mae = test_metrics['rmse'], test_metrics['mae']\n",
"        print(f'{best_model.model_name()} lr: {lr} test/rmse: {test_rmse} test/mae: {test_mae}')\n",
"        if not track_with_wandb:\n",
"            continue\n",
"        wandb.log({\"test/rmse\": test_rmse, \"test/mae\": test_mae })\n",
"        wandb_run.finish()\n",
"\n",
"# Model variants compared in the experiment, as (architecture, use_spatial, corr_name, corr_scope)\n",
"model_configs = [\n",
"    (GRU, False, None, None),\n",
"    (DGT, False, None, None),\n",
"    (DGT, True, None, None),\n",
"    (DGT, True, 'mi', 'global'),\n",
"    (DGT, True, 'mi', 'local'),\n",
"    (DGT, True, 'mi', 'dual'),\n",
"    (DGT, True, 'pcc', 'global'),\n",
"    (DGT, True, 'pcc', 'local'),\n",
"    (DGT, True, 'pcc', 'dual'),\n",
"]\n",
"\n",
"num_epochs = 100\n",
"track_with_wandb = True\n",
"\n",
"if track_with_wandb:\n",
"    wandb.login()\n",
"\n",
"# Train and test every variant in turn\n",
"for config in model_configs:\n",
"    run((config, num_epochs, track_with_wandb))"
]
},
{
"cell_type": "markdown",
"source": [
"# Results Show Dual Mutual Information Achieves the Lowest Test RMSE, Followed by Local Pearson"
],
"metadata": {
"id": "AvlH4SXnPK7X"
}
},
{
"cell_type": "code",
"source": [
"# Function to test a bunch of models given by model_configs on the test set\n",
"def test(model_configs):\n",
"    \"\"\"Evaluate each model in `model_configs` on its test samples.\n",
"\n",
"    Every config is a `(gnn, use_spatial, corr_name, corr_scope, lr)` tuple. Configs for which\n",
"    `get_model` returns None are skipped. Returns a dict mapping each remaining config to its\n",
"    `eval` result.\n",
"    \"\"\"\n",
"    results = {}\n",
"    for config in model_configs:\n",
"        gnn, use_spatial, corr_name, corr_scope, lr = config\n",
"        test_samples = get_dataset(corr_name, corr_scope)['test_samples']\n",
"        model = get_model(gnn, use_spatial, corr_name, corr_scope, lr=lr, load_weights=True)\n",
"        if model is not None:\n",
"            results[config] = eval(model=model, eval_dataset=test_samples)\n",
"    return results\n",
"\n",
"# Each model is tested under the learning rate that gave its best validation performance.\n",
"# Entries are (architecture, use_spatial, corr_name, corr_scope, lr).\n",
"model_configs = [\n",
"    (GRU, False, None, None, 0.01),\n",
"    (DGT, False, None, None, 0.1),\n",
"    (DGT, True, None, None, 0.01),\n",
"    (DGT, True, 'mi', 'global', 0.01),\n",
"    (DGT, True, 'mi', 'local', 0.01),\n",
"    (DGT, True, 'mi', 'dual', 0.01),\n",
"    (DGT, True, 'pcc', 'global', 0.01),\n",
"    (DGT, True, 'pcc', 'local', 0.01),\n",
"    (DGT, True, 'pcc', 'dual', 0.01),\n",
"]\n",
"test_results = test(model_configs)"
],
"metadata": {
"id": "xlwKYtG4PhYp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Build one table row per tested model and show the table\n",
"test_results_df = pd.DataFrame([\n",
"    {\n",
"        'Architecture': arch.__name__,\n",
"        'Use Spatial': use_spatial,\n",
"        'Correlation': corr_name,\n",
"        'Scope': corr_scope,\n",
"        'RMSE': result['rmse'],\n",
"        'MAE': result['mae'],\n",
"    }\n",
"    for (arch, use_spatial, corr_name, corr_scope, lr), result in test_results.items()\n",
"])\n",
"\n",
"# Worst model first, best model last\n",
"test_results_df = test_results_df.sort_values(by='RMSE', ascending=False)\n",
"print('Test results sorted in descending RMSE (Lower the better)')\n",
"display(test_results_df)"
],
"metadata": {
"id": "qX7TIllclzN1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 349
},
"outputId": "8afde05b-f7db-4424-c921-b084467e30d1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test results sorted in descending RMSE (Lower the better)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" Architecture Use Spatial Correlation Scope RMSE MAE\n",
"0 GRU False None None 3.116608 0.407619\n",
"2 DGT True None None 1.473705 0.292516\n",
"6 DGT True pcc global 0.643023 0.147955\n",
"3 DGT True mi global 0.465103 0.114918\n",
"1 DGT False None None 0.432899 0.181658\n",
"4 DGT True mi local 0.324191 0.133980\n",
"8 DGT True pcc dual 0.296782 0.157395\n",
"7 DGT True pcc local 0.294103 0.086751\n",
"5 DGT True mi dual 0.259889 0.099287"
],
"text/html": [
"\n",
" <div id=\"df-92cf656f-8338-4337-9e16-7c39071d5f59\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Architecture</th>\n",
" <th>Use Spatial</th>\n",
" <th>Correlation</th>\n",
" <th>Scope</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GRU</td>\n",
" <td>False</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>3.116608</td>\n",
" <td>0.407619</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>1.473705</td>\n",
" <td>0.292516</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>pcc</td>\n",
" <td>global</td>\n",
" <td>0.643023</td>\n",
" <td>0.147955</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>mi</td>\n",
" <td>global</td>\n",
" <td>0.465103</td>\n",
" <td>0.114918</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>DGT</td>\n",
" <td>False</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>0.432899</td>\n",
" <td>0.181658</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>mi</td>\n",
" <td>local</td>\n",
" <td>0.324191</td>\n",
" <td>0.133980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>pcc</td>\n",
" <td>dual</td>\n",
" <td>0.296782</td>\n",
" <td>0.157395</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>pcc</td>\n",
" <td>local</td>\n",
" <td>0.294103</td>\n",
" <td>0.086751</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>DGT</td>\n",
" <td>True</td>\n",
" <td>mi</td>\n",
" <td>dual</td>\n",
" <td>0.259889</td>\n",
" <td>0.099287</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-92cf656f-8338-4337-9e16-7c39071d5f59')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-92cf656f-8338-4337-9e16-7c39071d5f59 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-92cf656f-8338-4337-9e16-7c39071d5f59');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-edacbc5c-7fbd-4388-936a-c83a6874e21e\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-edacbc5c-7fbd-4388-936a-c83a6874e21e')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-edacbc5c-7fbd-4388-936a-c83a6874e21e button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
"\n",
" <div id=\"id_c6abb6c9-774e-4a19-946a-c37573a17e51\">\n",
" <style>\n",
" .colab-df-generate {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-generate:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-generate {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-generate:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
" <button class=\"colab-df-generate\" onclick=\"generateWithVariable('test_results_df')\"\n",
" title=\"Generate code using this dataframe.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
" </svg>\n",
" </button>\n",
" <script>\n",
" (() => {\n",
" const buttonEl =\n",
" document.querySelector('#id_c6abb6c9-774e-4a19-946a-c37573a17e51 button.colab-df-generate');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" buttonEl.onclick = () => {\n",
" google.colab.notebook.generateWithVariable('test_results_df');\n",
" }\n",
" })();\n",
" </script>\n",
" </div>\n",
"\n",
" </div>\n",
" </div>\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "test_results_df",
"summary": "{\n \"name\": \"test_results_df\",\n \"rows\": 9,\n \"fields\": [\n {\n \"column\": \"Architecture\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"DGT\",\n \"GRU\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Use Spatial\",\n \"properties\": {\n \"dtype\": \"boolean\",\n \"num_unique_values\": 2,\n \"samples\": [\n true,\n false\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Correlation\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"mi\",\n \"pcc\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Scope\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 3,\n \"samples\": [\n \"global\",\n \"local\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"RMSE\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.9432741552197175,\n \"min\": 0.2598891143975098,\n \"max\": 3.1166079945745024,\n \"num_unique_values\": 9,\n \"samples\": [\n 0.2941033328303634,\n 1.473704888039203\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"MAE\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.10462648465800446,\n \"min\": 0.08675127476453781,\n \"max\": 0.4076187312602997,\n \"num_unique_values\": 9,\n \"samples\": [\n 0.08675127476453781,\n 0.2925158739089966\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KTzU4f_F0dFI"
},
"source": [
"# Visualize Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xDyg3IUk0dFI"
},
"outputs": [],
"source": [
"# Load the price table, index it by date, and map each stock name to its column position\n",
"df = pd.read_csv(f'{workdir}/sp500.csv')\n",
"df = df.assign(Date=pd.to_datetime(df['Date'])).set_index('Date')\n",
"stock_names = df.columns\n",
"stock_lookup = dict(zip(stock_names, range(len(stock_names))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3qwQR-z00dFJ"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from cycler import cycler\n",
"\n",
"# Plot regression results on test set\n",
"def plot_regression(configs, labels, fig_name, stock_name, test_results):\n",
"    \"\"\"Plot predicted vs. real values of one stock over the test set and save the figure.\n",
"\n",
"    Args:\n",
"        configs: Keys into `test_results` for the models to draw, one line per model.\n",
"        labels: Legend labels, paired with `configs` by position.\n",
"        fig_name: Path the figure is saved to.\n",
"        stock_name: Name of the stock to plot; resolved to an index through `stock_lookup`.\n",
"        test_results: Mapping from config to a result holding `ys` and `y_hats` tensors.\n",
"\n",
"    Raises:\n",
"        ValueError: If no (config, label) pair is given, since there is then nothing to plot.\n",
"    \"\"\"\n",
"    series = list(zip(configs, labels))\n",
"    if not series:\n",
"        raise ValueError('plot_regression needs at least one (config, label) pair')\n",
"\n",
"    stock_index = stock_lookup[stock_name]\n",
"    colors = [(0.650, 0.120, 0.240, 0.6),  # red\n",
"              (0.122, 0.467, 0.706, 0.6),  # blue\n",
"              (1.000, 0.498, 0.055),       # orange\n",
"              (0.580, 0.403, 0.741, 0.6),  # purple\n",
"              ]\n",
"    # Create the figure explicitly: calling plt.clf() first would open (and display) a blank figure,\n",
"    # and plt.rc() would change the color cycle of every later plot in the notebook.\n",
"    fig, ax = plt.subplots(figsize=(10, 6))\n",
"    ax.set_prop_cycle(cycler('color', colors))\n",
"\n",
"    for config, label in series:\n",
"        result = test_results[config]\n",
"        y_hats = torch.tensor([y_hat[stock_index] for y_hat in result['y_hats'].cpu()])\n",
"        ax.plot(np.arange(len(y_hats)), y_hats, label=label, linewidth=1)\n",
"\n",
"    # The real series is taken from the last plotted result\n",
"    # (assumes every config was evaluated on the same test samples - TODO confirm)\n",
"    ys = torch.tensor([y[stock_index] for y in result['ys'].cpu()])\n",
"    ax.plot(np.arange(len(ys)), ys, label=\"Real\", color='green')\n",
"    ax.legend(fontsize=14)\n",
"    ax.set_title(f'Predicted vs Real {stock_name} Stock Price on Test', fontsize=20)\n",
"    ax.set_xlabel('Days', fontsize=16)\n",
"    ax.set_ylabel('Normalized Price', fontsize=16)\n",
"    ax.tick_params(axis='x', labelsize=16)\n",
"    ax.tick_params(axis='y', labelsize=16)\n",
"    fig.savefig(fig_name)\n",
"    plt.show()\n",
"    # Release the figure so repeated calls do not accumulate open figures\n",
"    plt.close(fig)\n"
]
},
{
"cell_type": "markdown",
"source": [
"## Local Mutual Information Outperforms Global and Dual"
],
"metadata": {
"id": "Cymu6slGN3fj"
}
},
{
"cell_type": "code",
"source": [
"# Compare the three mutual-information scopes on AAPL\n",
"plot_regression(\n",
"    configs=[(DGT, True, 'mi', scope, 0.01) for scope in ('global', 'local', 'dual')],\n",
"    labels=['Global MI with DGT', 'Local MI with DGT', 'Dual MI with DGT'],\n",
"    fig_name=f'{workdir}/sp500_AAPL_MI.png',\n",
"    stock_name='AAPL',\n",
"    test_results=test_results,\n",
")"
],
"metadata": {
"id": "vnTk76ijNpuK",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 601
},
"outputId": "6e9fbd7a-2dfe-44c1-9645-59a3899ba214"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Dual Pearson Outperforms Global and Local"
],
"metadata": {
"id": "TQpB3LkkOFTq"
}
},
{
"cell_type": "code",
"source": [
"# Compare the three Pearson-correlation scopes on AAPL\n",
"plot_regression(\n",
"    configs=[(DGT, True, 'pcc', scope, 0.01) for scope in ('global', 'local', 'dual')],\n",
"    labels=['Global Pearson with DGT', 'Local Pearson with DGT', 'Dual Pearson with DGT'],\n",
"    fig_name=f'{workdir}/sp500_AAPL_PCC.png',\n",
"    stock_name='AAPL',\n",
"    test_results=test_results,\n",
")"
],
"metadata": {
"id": "u9f7Wu6iNoIF",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 601
},
"outputId": "df692e24-bf65-4358-fe9b-d19a7e9b7067"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment