Last active
April 12, 2023 08:54
-
-
Save liyu1981/caf28bd5d34343d2759233597962b122 to your computer and use it in GitHub Desktop.
A simple function to animate the training process in Google Colab with PyTorch.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage:
# Just copy this function into your Google Colab notebook, then call it.
# A working example can be seen in this demo notebook:
# https://colab.research.google.com/drive/19Ni0EfOExQmTcrFZh3Z6DxsOM25KULqN?usp=sharing
import matplotlib.pyplot as plt | |
from matplotlib import animation, rc | |
def make_animate_train(train_step_fn, animate_setup_fn):
    """Build a ``train`` function that returns an animation of the training run.

    The returned function has the signature::

        train(net, x, y, optimizer, loss_func, iterations=100, interval=100)

    and returns a ``matplotlib.animation.FuncAnimation`` that can be shown in
    a Colab notebook. As a side effect it sets the global matplotlib rc option
    ``animation.html`` to ``'jshtml'``.

    Training is driven by the animation: ``train_step_fn`` is called once per
    rendered frame. Nothing is trained until the animation is rendered, and
    rendering it again keeps training the same ``net``.

    :param train_step_fn: a function with spec
        ``train_step(i, plot_state, net, x, y, optimizer, loss_func)``
        returning a tuple ``(loss_value, plot_state)``, where

        - ``i`` -- int, the current iteration (frame) number
        - ``plot_state`` -- tuple ``(ax, plot_elem1, plot_elem2, ...)``,
          e.g. ``(ax, line)``. This is the state returned by the previous
          step, or by ``animate_setup_fn`` for the first step.
        - ``net`` -- the model being trained
        - ``x`` -- torch tensor, input data
        - ``y`` -- torch tensor, target values
        - ``optimizer`` -- torch optimizer
        - ``loss_func`` -- torch loss function
        - ``loss_value`` -- the step's loss. It is printed with ``%f``, so it
          must be convertible to float.
    :param animate_setup_fn: a function with spec ``animate_setup(ax)``
        returning a tuple ``(ax, plot_elem1, plot_elem2, ...)``, e.g.
        ``(ax, line)``, where ``ax`` is the ``Axes`` instance returned by
        ``plt.subplots()``. The elements after ``ax`` are handed to
        ``FuncAnimation`` with ``blit=True``, so they must be matplotlib
        artists.
    :return: the ``train`` function described above.
    """
    def train(net, x, y, optimizer, loss_func, iterations=100, interval=100):
        """Return an animation that runs ``iterations`` training steps.

        :param iterations: number of frames, i.e. training steps.
        :param interval: delay between frames in milliseconds (default 100).
        """
        fig, ax = plt.subplots()
        # Close the figure immediately so the notebook does not also display a
        # static copy; the animation still draws onto ``fig``.
        plt.close(fig)
        plot_state = animate_setup_fn(ax)

        def anim_init():
            # Artists to redraw via blitting: everything except ``ax``.
            return plot_state[1:]

        def anim_frame(i):
            # Train the model while generating each frame. ``nonlocal`` is
            # required so the state returned by this step is the one the
            # next step (and ``anim_init``) sees; a plain assignment would
            # only bind a local and silently drop the update.
            nonlocal plot_state
            loss, plot_state = train_step_fn(
                i, plot_state, net, x, y, optimizer, loss_func)
            print("\riteration %d done with loss=%f." % (i, loss), end='')
            return plot_state[1:]

        anim = animation.FuncAnimation(fig, anim_frame, init_func=anim_init,
                                       frames=iterations, interval=interval,
                                       blit=True)
        # Render animations as JavaScript/HTML so they display in Colab.
        rc('animation', html='jshtml')
        return anim

    return train
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
or see this notebook: https://colab.research.google.com/drive/19Ni0EfOExQmTcrFZh3Z6DxsOM25KULqN?usp=sharing