@fgolemo
Created March 1, 2022 22:25
Toy example for Autobots: Latent Variable Sequential Set Transformers
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "autobot_toy.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "EB9MwlqV98w-"
},
"source": [
"# AutoBot Toy Dataset Modelling"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"id": "6lMLqmTN9i8G"
},
"source": [
"## Generate Small Synthetic Non-linear Particle Accelerator Dataset\n",
"In this section, we generate our tiny toy dataset that showcases AutoBot's ability to model multimodal trajectories. We generate these trajectories by adopting a simple bicycle model that turns with a constant steering angle at a constant speed. This toy dataset not only demonstrates the multimodal trajectories AutoBot generates, but also shows the importance of the entropy regularization term."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 596
},
"id": "cBIkHktQ9i8K",
"outputId": "a4a79aad-763c-4b8a-dee5-4b9c67856ce1"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"l = 2  # length of bicycle\n",
"dt = 0.5\n",
"start_pos = np.array([-0.0, -10.0])\n",
"data = []\n",
"\n",
"# we'll generate a total of 6 trajectories\n",
"\n",
"# we will have 2 trajectories go left, 2 go straight, 2 go right\n",
"phis = [-0.2, 0.0, 0.2]\n",
"\n",
"# the trajectories will go left/straight/right at one of 2 possible speeds\n",
"speeds = [1.5, 3.0]\n",
"\n",
"configs = np.array(np.meshgrid(phis, speeds)).T.reshape(-1, 2)\n",
"\n",
"fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))\n",
"row = 0\n",
"for i, config in enumerate(configs):\n",
"    wheel_pos = []\n",
"    speed = 3.0\n",
"    heading = np.pi / 2\n",
"    phi = 0.0\n",
"    positions = {\n",
"        \"rear\": np.array([0.0, 0.0]) + start_pos,\n",
"        \"front\": np.array([l * np.cos(heading), l * np.sin(heading)]) + start_pos\n",
"    }\n",
"\n",
"    # we are generating trajectories of total length 18,\n",
"    # 6 of which will be used as input trajectory and\n",
"    # the remaining 12 as output trajectory\n",
"    for t in range(18):\n",
"\n",
"        # steps 0 to 6 drive straight up at the default speed; because the test is\n",
"        # `t > 6`, the configured steering angle and speed only apply from step 7 on\n",
"        if t > 6:\n",
"            phi, speed = config\n",
"\n",
"        # kinematic bicycle update: the rear wheel moves along the heading from\n",
"        # before this step's turn, then the heading rotates with the steering angle\n",
"        x_v = speed*np.cos(heading)\n",
"        y_v = speed*np.sin(heading)\n",
"        omega = speed*np.tan(phi)/l\n",
"        heading += omega * dt\n",
"        positions[\"rear\"] += np.array([x_v * dt, y_v * dt])\n",
"        positions[\"front\"] = positions[\"rear\"] + np.array([l * np.cos(heading), l * np.sin(heading)])\n",
"        wheel_pos.append([positions[\"rear\"][0], positions[\"rear\"][1], positions[\"front\"][0], positions[\"front\"][1]])\n",
"    data.append(np.array(wheel_pos))\n",
"\n",
"    # plotting the data\n",
"    col = i % 3\n",
"    if i > 0 and i % 3 == 0:\n",
"        row += 1\n",
"    ax[row, col].scatter(np.array(wheel_pos)[:6, 0], np.array(wheel_pos)[:6, 1], color='#94D0FF', label='past', s=40)\n",
"    ax[row, col].scatter(np.array(wheel_pos)[6:, 0], np.array(wheel_pos)[6:, 1], color='#FF6AD5', label='future', s=40)\n",
"    ax[row, col].axis(xmin=-15, xmax=15, ymin=-15, ymax=20)\n",
"\n",
"ax[1, 2].legend()\n",
"plt.show()\n",
"data = np.array(data)[:, :, :2]"
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x720 with 6 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9JaFXTes9i8M"
},
"source": [
"## Creating a pytorch dataloader"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "SlVSU4jFSyC-",
"outputId": "50d4c9c3-2b1a-471d-934f-02fb05a50fe3"
},
"source": [
"# Code tested with pytorch 1.6.0\n",
"!pip install -q torch==1.6.0 torchvision==0.7.0"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"ERROR: torchtext 0.9.1 has requirement torch==1.8.1, but you'll have torch 1.6.0 which is incompatible.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "P0eRpxAs9i8M"
},
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class NPYFakeDataset(Dataset):\n",
"    def __init__(self):\n",
"        self.ego_dataset = data\n",
"\n",
"    def get_input_output_seqs(self, ego_data):\n",
"        # 6 input timesteps, (cyan-colored in the plot above),\n",
"        # which are identical across all 6 examples.\n",
"        ego_in = ego_data[:6]\n",
"\n",
"        # 12 output (to be predicted by the model) timesteps,\n",
"        # (pink in the plot above)\n",
"        ego_out = ego_data[6:]\n",
"\n",
"        return ego_in, ego_out\n",
"\n",
"    def __getitem__(self, idx: int):\n",
"        ego_data = self.ego_dataset[idx]\n",
"        in_ego, out_ego = self.get_input_output_seqs(ego_data)\n",
"        return in_ego, out_ego\n",
"\n",
"    def __len__(self):\n",
"        return len(self.ego_dataset)\n"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "7sg9ohVJ9i8N"
},
"source": [
"## Model Code"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9On9HNKh9i8N"
},
"source": [
"import math\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def init(module, weight_init, bias_init, gain=1):\n",
"    weight_init(module.weight.data, gain=gain)\n",
"    bias_init(module.bias.data)\n",
"    return module\n",
"\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
"    '''\n",
"    Sine/cosine positional encoding (standard procedure for transformer sequential inputs)\n",
"    '''\n",
"\n",
"    def __init__(self, d_model, dropout=0.1, max_len=20):\n",
"        super(PositionalEncoding, self).__init__()\n",
"        self.dropout = nn.Dropout(p=dropout)\n",
"        pe = torch.zeros(max_len, d_model)\n",
"        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
"        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
"        pe[:, 0::2] = torch.sin(position * div_term)\n",
"        pe[:, 1::2] = torch.cos(position * div_term)\n",
"        pe = pe.unsqueeze(0).transpose(0, 1)\n",
"        self.register_buffer('pe', pe)\n",
"\n",
"    def forward(self, x):\n",
"        '''\n",
"        :param x: must be (T, B, H)\n",
"        :return:\n",
"        '''\n",
"        x = x + self.pe[:x.size(0), :]\n",
"        return self.dropout(x)\n",
"\n",
"\n",
"class AutoBotEgo(nn.Module):\n",
"    '''\n",
"    Sequential Set Transformer model for Small Synthetic Non-linear Particle Accelerator Dataset.\n",
"    '''\n",
"    def __init__(self, d_k=64, num_modes=3):\n",
"        super(AutoBotEgo, self).__init__()\n",
"\n",
"        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))\n",
"\n",
"        self.d_k = d_k\n",
"        self.num_modes = num_modes\n",
"        self.num_heads = 8\n",
"\n",
"        self.output_model = OutputModelBVG(d_k=d_k)\n",
"\n",
"        tx_encoder_layer = nn.TransformerEncoderLayer(d_model=d_k, nhead=self.num_heads)\n",
"        self.tx_encoder = nn.TransformerEncoder(tx_encoder_layer, num_layers=1)\n",
"\n",
"        self.emb_pos = init_(nn.Linear(2, d_k))\n",
"\n",
"        tx_decoder_layer = nn.TransformerDecoderLayer(d_model=d_k, nhead=self.num_heads)\n",
"        self.tx_decoder = nn.TransformerDecoder(tx_decoder_layer, num_layers=1)\n",
"\n",
"        self.pos_encoder = PositionalEncoding(d_k, dropout=0.0)\n",
"\n",
"        self.emb_intention = nn.Sequential(\n",
"            init_(nn.Linear(num_modes, d_k))\n",
"        )\n",
"        self.emb_posint = nn.Sequential(\n",
"            init_(nn.Linear(2*d_k, d_k)), nn.ReLU(),\n",
"            init_(nn.Linear(d_k, d_k))\n",
"        )\n",
"\n",
"        self.mode_parameters = nn.Parameter(torch.Tensor(1, num_modes, d_k))\n",
"        nn.init.xavier_uniform_(self.mode_parameters)\n",
"        self.prob_decoder = nn.TransformerDecoderLayer(d_model=d_k, nhead=8)\n",
"        self.prob_predictor = init_(nn.Linear(d_k, 1))\n",
"\n",
"        self.train()\n",
"\n",
"    def generate_decoder_mask(self, seq_len, device):\n",
"        ''' For masking out the subsequent info. '''\n",
"        subsequent_mask = (torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1)).bool()\n",
"        return subsequent_mask\n",
"\n",
"    def forward(self, ego_input_positions, ego_output_positions):\n",
"        B = ego_input_positions.size(0)\n",
"        horizon = ego_output_positions.size(1)\n",
"\n",
"        # Encode all observations\n",
"        encoded_obs = self.emb_pos(ego_input_positions).transpose(0, 1)\n",
"\n",
"        # Add positional encoding\n",
"        encoded_obs = self.pos_encoder(encoded_obs)\n",
"\n",
"        # TX on input seqs\n",
"        in_memory = self.tx_encoder(encoded_obs)\n",
"        mode_probs = self.prob_decoder(self.mode_parameters.repeat(B, 1, 1).transpose(0, 1), in_memory).transpose(0,1)\n",
"        mode_probs = F.softmax(self.prob_predictor(mode_probs).squeeze(-1), dim=1)\n",
"\n",
"        intentions = torch.eye(self.num_modes).to(device=ego_input_positions.device).unsqueeze(0).repeat(B, 1, 1)\n",
"        enc_intentions = self.emb_intention(intentions).view(B*self.num_modes, self.d_k).unsqueeze(0)\n",
"        in_memory = in_memory.unsqueeze(2).repeat(1, 1, self.num_modes, 1).view(-1, B * self.num_modes, self.d_k)\n",
"\n",
"        pred_obs = [ego_input_positions[:, -1].unsqueeze(1).repeat(1, self.num_modes, 1).view(B * self.num_modes, -1)]\n",
"        dec_start_emb = self.emb_pos(torch.stack(pred_obs, dim=0))\n",
"        dec_input_emb = dec_start_emb\n",
"        for ts in range(horizon):  # autoregressive rollout\n",
"            T = len(dec_input_emb)\n",
"            curr_intentions = enc_intentions.repeat(T, 1, 1)\n",
"            out_emb = torch.cat((curr_intentions, dec_input_emb), dim=-1)\n",
"            out_emb = self.emb_posint(out_emb)\n",
"\n",
"            out_emb = self.pos_encoder(out_emb)\n",
"            time_masks = self.generate_decoder_mask(seq_len=T, device=ego_input_positions.device)\n",
"            out_seq = self.tx_decoder(out_emb, in_memory, tgt_mask=time_masks)\n",
"            dec_input_emb = torch.cat((dec_start_emb, out_seq), dim=0)\n",
"\n",
"        out_dists = self.output_model(out_seq).view(horizon, B, self.num_modes, -1).permute(2, 0, 1, 3)\n",
"        return out_dists, mode_probs\n",
"\n",
"\n",
"class OutputModelBVG(nn.Module):\n",
"    def __init__(self, d_k=64):\n",
"        super(OutputModelBVG, self).__init__()\n",
"        self.d_k = d_k\n",
"        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))\n",
"        self.observation_model = nn.Sequential(\n",
"            init_(nn.Linear(d_k, d_k)), nn.ReLU(),\n",
"            init_(nn.Linear(d_k, d_k)), nn.ReLU(),\n",
"            init_(nn.Linear(d_k, 5))\n",
"        )\n",
"        self.min_stdev = 0.1  # for stability.\n",
"\n",
"    def forward(self, agent_latent_state):\n",
"        '''\n",
"        :param agent_latent_state: the hidden-state of the ego-agent (B, T, H).\n",
"        :return: A tensor with dimension (B, T, 5) where the 5-D vectors correspond\n",
"                 to the parameters of a bivariate Gaussian.\n",
"        '''\n",
"        pred_obs = self.observation_model(agent_latent_state)\n",
"        x_mean = pred_obs[:, :, 0]\n",
"        y_mean = pred_obs[:, :, 1]\n",
"        x_sigma = F.softplus(pred_obs[:, :, 2]) + self.min_stdev\n",
"        y_sigma = F.softplus(pred_obs[:, :, 3]) + self.min_stdev\n",
"        rho = torch.tanh(pred_obs[:, :, 4]) * 0.9  # for stability\n",
"        return torch.stack([x_mean, y_mean, x_sigma, y_sigma, rho], dim=2)\n",
"\n"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hvQG71Au9i8Q"
},
"source": [
"## Utility Functions\n",
"We define some utility functions for plotting confidence ellipses for the output distributions (mean and covariance at each timestep) and for calculating the multimodal loss.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Iwz5gzTE9i8S"
},
"source": [
"import numpy as np\n",
"import torch\n",
"from scipy import special\n",
"import torch.distributions as D\n",
"from torch.distributions import MultivariateNormal\n",
"from matplotlib.patches import Ellipse\n",
"\n",
"\n",
"def _plot_gaussian(dist, ax, color, zorder=0):\n",
"    \"\"\"Plots the 95% confidence ellipse of a given bivariate Gaussian\"\"\"\n",
"    cov_val = dist[4] * dist[2] * dist[3]\n",
"    mean = [dist[0], dist[1]]\n",
"    covariance = np.array([[dist[2] ** 2, cov_val], [cov_val, dist[3] ** 2]])\n",
"\n",
"    if covariance.ndim == 1:\n",
"        covariance = np.diag(covariance)\n",
"\n",
"    radius = np.sqrt(5.991)  # for 95% confidence interval.\n",
"    eigvals, eigvecs = np.linalg.eig(covariance)\n",
"    axis = np.sqrt(eigvals) * radius\n",
"    slope = eigvecs[1][0] / eigvecs[1][1]\n",
"    angle = 180.0 * np.arctan(slope) / np.pi\n",
"\n",
"    e = Ellipse(mean, 2 * axis[0], 2 * axis[1], angle=angle, fill=False, color=color, linewidth=1, zorder=zorder, alpha=1.0)\n",
"    ax.add_artist(e)\n",
"    e.set_clip_box(ax.bbox)\n",
"    return ax\n",
"\n",
"\n",
"def get_BVG_distributions(pred):\n",
"    '''\n",
"    Transform the prediction tensor of dim (B, T, 5) to torch Multivariate Gaussian distributions.\n",
"    '''\n",
"    B = pred.size(0)\n",
"    T = pred.size(1)\n",
"    mu_x = pred[:, :, 0].unsqueeze(2)\n",
"    mu_y = pred[:, :, 1].unsqueeze(2)\n",
"    sigma_x = pred[:, :, 2]\n",
"    sigma_y = pred[:, :, 3]\n",
"    rho = pred[:, :, 4]\n",
"\n",
"    cov = torch.zeros((B, T, 2, 2)).to(pred.device)\n",
"    cov[:, :, 0, 0] = sigma_x ** 2\n",
"    cov[:, :, 1, 1] = sigma_y ** 2\n",
"    cov_val = rho * sigma_x * sigma_y\n",
"    cov[:, :, 0, 1] = cov_val\n",
"    cov[:, :, 1, 0] = cov_val\n",
"\n",
"    biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov)\n",
"    return biv_gauss_dist\n",
"\n",
"\n",
"def nll_pytorch_dist(pred, data):\n",
"    '''\n",
"    Args:\n",
"        pred: [B, T, 5]\n",
"        data: [B, T, 2]\n",
"    This function computes the negative log-likelihood of the data given the predicted distributions.\n",
"    Returns the nll vector for all elements in the batch.\n",
"    '''\n",
"    biv_gauss_dist = get_BVG_distributions(pred)\n",
"    loss = -biv_gauss_dist.log_prob(data).sum(1)  # sum over all timesteps\n",
"    return loss  # [B]\n",
"\n",
"\n",
"def nll_loss_multimodes(pred, data, modes_pred, entropy_weight=1.0, val_nll=False, kl_weight=1.0):\n",
"    \"\"\"NLL loss multimodes for training. MFP Loss function\n",
"    Args:\n",
"        pred: [K, T, B, 5]\n",
"        data: [B, T, 2]\n",
"        modes_pred: [B, K], prior prob over modes\n",
"    \"\"\"\n",
"    K = len(pred)\n",
"    T, B, dim = pred[0].shape\n",
"\n",
"    # Here, we compute the log-likelihood of the data given the predicted distributions, p(y|z,x).\n",
"    # This part is used in combination with the predicted prior distribution p(z|x) to compute the posterior p(z|y,x).\n",
"    log_lik = np.zeros((B, K))\n",
"    with torch.no_grad():\n",
"        for kk in range(K):\n",
"            nll = nll_pytorch_dist(pred[kk].transpose(0, 1), data)\n",
"            log_lik[:, kk] = -nll.cpu().numpy()\n",
"\n",
"    # The following is an application of Bayes Rule.\n",
"    priors = modes_pred.detach().cpu().numpy()\n",
"    log_post_unnorm = log_lik + np.log(priors)\n",
"    log_post = log_post_unnorm - special.logsumexp(log_post_unnorm, axis=1).reshape((B, 1))\n",
"    post_prob = np.exp(log_post)\n",
"    post_prob = torch.tensor(post_prob).float().to(data.device)\n",
"\n",
"    # Using the computed posterior, we now can compute the data negative loglikelihood exactly.\n",
"    loss = 0.0\n",
"    for kk in range(K):\n",
"        nll_k = nll_pytorch_dist(pred[kk].transpose(0, 1), data) * post_prob[:, kk]\n",
"        loss += nll_k.sum() / float(B)\n",
"\n",
"    # Compute the KL divergence between p(z|x) and p(z|x,y).\n",
"    kl_loss = torch.nn.KLDivLoss(reduction='batchmean')\n",
"    loss += kl_weight*kl_loss(torch.log(modes_pred), post_prob)\n",
"\n",
"    # The entropy regularization term.\n",
"    if not val_nll:\n",
"        entropy_vals = []\n",
"        for kk in range(K):\n",
"            entropy_vals.append(get_BVG_distributions(pred[kk]).entropy())\n",
"        entropy_loss = torch.mean(torch.stack(entropy_vals).permute(2, 0, 1).sum(2).max(1)[0])\n",
"        loss += entropy_weight*entropy_loss\n",
"\n",
"    return loss\n"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aXJMdKqU9i8U"
},
"source": [
"## Training Loop\n",
"The training loop takes about 10 minutes on a single-GPU machine."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "6mxWEJKd9i8V",
"outputId": "b36fbbc0-6d61-4615-f40f-ea8ba5f07e14"
},
"source": [
"import torch\n",
"from torch import optim\n",
"import torch.distributions as D\n",
"print(torch.__version__)\n",
"\n",
"\n",
"num_modes = 10\n",
"d_k = 64\n",
"learning_rate = 0.0001\n",
"entropy_weight = 5.0  # turn this up/down to see the effect on the variance.\n",
"seed = 0\n",
"np.random.seed(seed)\n",
"\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"    torch.cuda.manual_seed(seed)\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"# Initialize model\n",
"autobot_model = AutoBotEgo(d_k=d_k, num_modes=num_modes).to(device)\n",
"optimiser = optim.Adam(autobot_model.parameters(), lr=learning_rate, eps=1e-4)\n",
"\n",
"# Initialize dataloader\n",
"train_nuscenes = NPYFakeDataset()\n",
"train_loader = torch.utils.data.DataLoader(train_nuscenes, batch_size=6, shuffle=True, num_workers=3, drop_last=True, pin_memory=True)\n",
"\n",
"total_steps = 0\n",
"losses = []\n",
"for train_iter in range(0, 3000):\n",
"    for i, data in enumerate(train_loader):\n",
"        ego_in, ego_out = data\n",
"        ego_in = ego_in.float().to(device)\n",
"        ego_out = ego_out.float().to(device)\n",
"\n",
"        # encode observations\n",
"        pred_obs, modes_pred = autobot_model(ego_in, ego_out)\n",
"\n",
"        # Compute the loss.\n",
"        loss = nll_loss_multimodes(pred_obs, ego_out[:, :, :2], modes_pred, entropy_weight=entropy_weight)\n",
"\n",
"        # A measure of the entropy of the output distributions.\n",
"        sigmas = pred_obs[:, :, :, 2:4]\n",
"        sigma_magnitude = torch.mean(torch.norm(sigmas, dim=-1))\n",
"\n",
"        optimiser.zero_grad()\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(autobot_model.parameters(), 0.5)\n",
"        optimiser.step()\n",
"\n",
"        # Store the total training loss (NLL + KL + entropy terms)\n",
"        losses.append(loss.item())\n",
"\n",
"    if train_iter % 50 == 0:\n",
"        print(train_iter, \"Obs_Loss\", losses[-1], \"Prior Entropy\", torch.mean(D.Categorical(modes_pred).entropy()).item(), \"Sigma Magnitude\", sigma_magnitude.item())\n"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"1.6.0\n",
"0 Obs_Loss 2369.8955078125 Prior Entropy 1.4875106811523438 Sigma Magnitude 0.8969748616218567\n",
"50 Obs_Loss 210.09817504882812 Prior Entropy 2.1081435680389404 Sigma Magnitude 1.4822051525115967\n",
"100 Obs_Loss 157.78390502929688 Prior Entropy 2.1547019481658936 Sigma Magnitude 1.066394567489624\n",
"150 Obs_Loss 130.46844482421875 Prior Entropy 2.1374456882476807 Sigma Magnitude 0.9238507747650146\n",
"200 Obs_Loss 93.3446044921875 Prior Entropy 2.1341917514801025 Sigma Magnitude 0.7530378699302673\n",
"250 Obs_Loss 72.35761260986328 Prior Entropy 2.1071135997772217 Sigma Magnitude 0.6062721014022827\n",
"300 Obs_Loss 80.39238739013672 Prior Entropy 2.0682501792907715 Sigma Magnitude 0.5384072065353394\n",
"350 Obs_Loss 51.74512481689453 Prior Entropy 2.0305049419403076 Sigma Magnitude 0.47288694977760315\n",
"400 Obs_Loss 32.685508728027344 Prior Entropy 2.011359214782715 Sigma Magnitude 0.3868240714073181\n",
"450 Obs_Loss 11.433774948120117 Prior Entropy 2.0145623683929443 Sigma Magnitude 0.3831108808517456\n",
"500 Obs_Loss 9.873607635498047 Prior Entropy 1.97100031375885 Sigma Magnitude 0.33936724066734314\n",
"550 Obs_Loss 8.771446228027344 Prior Entropy 2.0121748447418213 Sigma Magnitude 0.31136900186538696\n",
"600 Obs_Loss -19.142765045166016 Prior Entropy 1.9727678298950195 Sigma Magnitude 0.28895100951194763\n",
"650 Obs_Loss -13.80767822265625 Prior Entropy 1.9760602712631226 Sigma Magnitude 0.2773299515247345\n",
"700 Obs_Loss -34.384971618652344 Prior Entropy 1.9884058237075806 Sigma Magnitude 0.2564449906349182\n",
"750 Obs_Loss -39.28346252441406 Prior Entropy 1.9636551141738892 Sigma Magnitude 0.24932558834552765\n",
"800 Obs_Loss -41.30181121826172 Prior Entropy 1.987878441810608 Sigma Magnitude 0.23153577744960785\n",
"850 Obs_Loss -48.81230926513672 Prior Entropy 1.9718965291976929 Sigma Magnitude 0.22708381712436676\n",
"900 Obs_Loss -61.86613845825195 Prior Entropy 1.983053207397461 Sigma Magnitude 0.2147447019815445\n",
"950 Obs_Loss -59.352760314941406 Prior Entropy 1.9651894569396973 Sigma Magnitude 0.2024281769990921\n",
"1000 Obs_Loss -57.75294876098633 Prior Entropy 1.949167251586914 Sigma Magnitude 0.20529471337795258\n",
"1050 Obs_Loss -62.86778259277344 Prior Entropy 1.961978793144226 Sigma Magnitude 0.20152588188648224\n",
"1100 Obs_Loss -73.48599243164062 Prior Entropy 1.9676356315612793 Sigma Magnitude 0.19365176558494568\n",
"1150 Obs_Loss -69.45549774169922 Prior Entropy 1.951179027557373 Sigma Magnitude 0.18875083327293396\n",
"1200 Obs_Loss -65.29815673828125 Prior Entropy 1.9962520599365234 Sigma Magnitude 0.18455862998962402\n",
"1250 Obs_Loss -74.91232299804688 Prior Entropy 1.9666643142700195 Sigma Magnitude 0.1774861216545105\n",
"1300 Obs_Loss -82.59066772460938 Prior Entropy 1.9747203588485718 Sigma Magnitude 0.17883709073066711\n",
"1350 Obs_Loss -83.88932037353516 Prior Entropy 1.95893132686615 Sigma Magnitude 0.17508520185947418\n",
"1400 Obs_Loss -94.74249267578125 Prior Entropy 1.9619406461715698 Sigma Magnitude 0.17152521014213562\n",
"1450 Obs_Loss -83.56954956054688 Prior Entropy 1.9574122428894043 Sigma Magnitude 0.1688963770866394\n",
"1500 Obs_Loss -95.97639465332031 Prior Entropy 1.934739589691162 Sigma Magnitude 0.16633732616901398\n",
"1550 Obs_Loss -84.9005126953125 Prior Entropy 1.952906608581543 Sigma Magnitude 0.167943075299263\n",
"1600 Obs_Loss -93.68824768066406 Prior Entropy 1.9422146081924438 Sigma Magnitude 0.16340167820453644\n",
"1650 Obs_Loss -93.48336029052734 Prior Entropy 1.9700676202774048 Sigma Magnitude 0.1643861085176468\n",
"1700 Obs_Loss -93.68887329101562 Prior Entropy 1.947595238685608 Sigma Magnitude 0.16277040541172028\n",
"1750 Obs_Loss -106.33206176757812 Prior Entropy 1.9439034461975098 Sigma Magnitude 0.1631186455488205\n",
"1800 Obs_Loss -114.87834167480469 Prior Entropy 1.9303315877914429 Sigma Magnitude 0.16260327398777008\n",
"1850 Obs_Loss -111.52286529541016 Prior Entropy 1.9322032928466797 Sigma Magnitude 0.16073709726333618\n",
"1900 Obs_Loss -99.67212677001953 Prior Entropy 1.954646110534668 Sigma Magnitude 0.1583443582057953\n",
"1950 Obs_Loss -112.62327575683594 Prior Entropy 1.9359534978866577 Sigma Magnitude 0.15894180536270142\n",
"2000 Obs_Loss -112.17411041259766 Prior Entropy 1.9476476907730103 Sigma Magnitude 0.15647853910923004\n",
"2050 Obs_Loss -117.08065032958984 Prior Entropy 1.9370139837265015 Sigma Magnitude 0.15652796626091003\n",
"2100 Obs_Loss -116.01622772216797 Prior Entropy 1.943585753440857 Sigma Magnitude 0.15343520045280457\n",
"2150 Obs_Loss -104.03912353515625 Prior Entropy 1.920638918876648 Sigma Magnitude 0.15452148020267487\n",
"2200 Obs_Loss -118.58104705810547 Prior Entropy 1.945374846458435 Sigma Magnitude 0.15166401863098145\n",
"2250 Obs_Loss -127.10646057128906 Prior Entropy 1.934796929359436 Sigma Magnitude 0.15318652987480164\n",
"2300 Obs_Loss -129.23353576660156 Prior Entropy 1.9234099388122559 Sigma Magnitude 0.15286634862422943\n",
"2350 Obs_Loss -118.142333984375 Prior Entropy 1.9349241256713867 Sigma Magnitude 0.15083317458629608\n",
"2400 Obs_Loss -122.60208892822266 Prior Entropy 1.9366976022720337 Sigma Magnitude 0.1516687273979187\n",
"2450 Obs_Loss -132.3092498779297 Prior Entropy 1.9179943799972534 Sigma Magnitude 0.1510191559791565\n",
"2500 Obs_Loss -119.2080078125 Prior Entropy 1.937208652496338 Sigma Magnitude 0.15157970786094666\n",
"2550 Obs_Loss -126.15424346923828 Prior Entropy 1.913926601409912 Sigma Magnitude 0.1499032974243164\n",
"2600 Obs_Loss -127.50374603271484 Prior Entropy 1.9175949096679688 Sigma Magnitude 0.14963364601135254\n",
"2650 Obs_Loss -127.7978744506836 Prior Entropy 1.911878228187561 Sigma Magnitude 0.1486077606678009\n",
"2700 Obs_Loss -130.70504760742188 Prior Entropy 1.90842866897583 Sigma Magnitude 0.15001989901065826\n",
"2750 Obs_Loss -134.3357696533203 Prior Entropy 1.9386142492294312 Sigma Magnitude 0.1491604447364807\n",
"2800 Obs_Loss -137.72341918945312 Prior Entropy 1.9272490739822388 Sigma Magnitude 0.14933128654956818\n",
"2850 Obs_Loss -140.33743286132812 Prior Entropy 1.933469295501709 Sigma Magnitude 0.14862534403800964\n",
"2900 Obs_Loss -133.574951171875 Prior Entropy 1.9343080520629883 Sigma Magnitude 0.14803217351436615\n",
"2950 Obs_Loss -135.62042236328125 Prior Entropy 1.9297490119934082 Sigma Magnitude 0.14848901331424713\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tw0HaOtk9i8W"
},
"source": [
"## Testing Mode Learning on Toy Dataset\n",
"\n",
"This cell produces the results shown in Figure 2 (bottom row) of the paper. To get the results corresponding to the middle row, reduce the `entropy_weight` in the training block above and repeat training.\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 596
},
"id": "9o6cRo6y9i8W",
"outputId": "d1b20cec-7307-4486-ac96-36f6f0464730"
},
"source": [
"autobot_model.eval()\n",
"with torch.no_grad():\n",
"    for i, data in enumerate(train_loader):\n",
"        ego_in, ego_out = data\n",
"        ego_in = ego_in.float().to(device)\n",
"        ego_out = ego_out.float().to(device)\n",
"\n",
"        pred_obs, mode_preds = autobot_model(ego_in, ego_out)\n",
"        pred_positions = pred_obs[:, :, 0, :2].squeeze().cpu().numpy()\n",
"        mode_probs_np = mode_preds[0].squeeze().cpu().numpy()\n",
"        pred_distributions = pred_obs[:, :, 0].squeeze().cpu().numpy()\n",
"\n",
"        top_6_modes = mode_probs_np.argsort()[-6:][::-1]\n",
"        fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(15, 10))\n",
"        row = 0\n",
"        for k_idx in range(num_modes):\n",
"            col = k_idx % 5\n",
"            if k_idx > 0 and k_idx % 5 == 0:\n",
"                row += 1\n",
"            k = k_idx\n",
"            ax[row, col].scatter(pred_positions[k, :, 0], pred_positions[k, :, 1], s=10, color='k')\n",
"\n",
"            for t in range(12):\n",
"                ax[row, col] = _plot_gaussian(pred_distributions[k, t], ax[row, col], color='#966BFF')\n",
"            ax[row, col].axis(xmin=-15, xmax=15, ymin=-15, ymax=20)\n",
"\n",
"        plt.show()\n"
],
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x720 with 10 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
}
]
}
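
The data-generation cell of the notebook advances a simple bicycle model with a constant steering angle and constant speed. As a rough, standalone illustration of that update rule, the sketch below isolates one step into a function and rolls it out. The names `bicycle_step` and `rollout` are hypothetical helpers introduced here, not part of the notebook; the wheelbase (2), time step (0.5), start position and initial heading mirror the notebook's values. Unlike the notebook, this sketch applies the steering angle from the very first step.

```python
import numpy as np


def bicycle_step(rear, heading, speed, phi, wheelbase=2.0, dt=0.5):
    """One kinematic bicycle step (hypothetical helper, not in the notebook).

    The rear wheel moves along the heading it had at the start of the step,
    and the heading then rotates at rate speed * tan(phi) / wheelbase.
    """
    velocity = speed * np.array([np.cos(heading), np.sin(heading)])
    omega = speed * np.tan(phi) / wheelbase
    return rear + velocity * dt, heading + omega * dt


def rollout(phi, speed, steps=12, start=(0.0, -10.0), heading=np.pi / 2):
    """Roll the model forward and return the rear-wheel positions, shape (steps, 2)."""
    rear = np.array(start, dtype=float)
    points = []
    for _ in range(steps):
        rear, heading = bicycle_step(rear, heading, speed, phi)
        points.append(rear.copy())
    return np.array(points)


straight = rollout(phi=0.0, speed=3.0)  # stays on the line x = 0
left = rollout(phi=0.2, speed=3.0)      # heading grows past pi/2, so x decreases
```

Calling `rollout(phi=-0.2, speed=1.5)` gives the mirrored, slower turn, which corresponds to another of the six (steering angle, speed) combinations the notebook enumerates.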