from typing import Annotated, Union

import numpy as np
from pydantic import (
    AfterValidator,
    BaseModel,
)

# A float field that stores positive infinity as None after validation
FloatInfAsNone = Annotated[
    Union[float, None], AfterValidator(lambda x: None if x == np.inf else x)
]
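A minimal usage sketch, assuming pydantic v2. The model `TrainingSummary` and its field `best_val_loss` are hypothetical, and the sketch swaps `np.inf` for `math.inf` (they compare equal) so that it does not depend on NumPy. Because `AfterValidator` runs after the `Union[float, None]` check, an input of `float("inf")` is first accepted as a float and then replaced by `None`.

```python
import math
from typing import Annotated, Union

from pydantic import AfterValidator, BaseModel

# Same idea as FloatInfAsNone above, using math.inf instead of np.inf
FloatInfAsNone = Annotated[
    Union[float, None], AfterValidator(lambda x: None if x == math.inf else x)
]


class TrainingSummary(BaseModel):
    # Hypothetical model; the field name is illustrative only
    best_val_loss: FloatInfAsNone


finite = TrainingSummary(best_val_loss=0.25)            # kept as 0.25
infinite = TrainingSummary(best_val_loss=float("inf"))  # stored as None
missing = TrainingSummary(best_val_loss=None)           # None passes through unchanged
print(finite.best_val_loss, infinite.best_val_loss, missing.best_val_loss)
```

Only positive infinity is mapped; `-inf` still validates as a regular float.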
#!/bin/bash
#
# General usage:
#   ./rsync_script.sh [additional rsync options] <source> <destination>
#
# Example with additional options:
#   ./rsync_script.sh --bwlimit=1000 --info=progress2 /path/to/source user@destination:/path/to/destination

# Check that at least <source> and <destination> were provided
if [ $# -lt 2 ]; then
    echo "Usage: $0 [additional rsync options] <source> <destination>" >&2
    exit 1
fi
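The usage line implies a calling convention: any number of rsync options come first, and the last two arguments are always the source and the destination. The sketch below shows that split in Python. `build_rsync_command` is a hypothetical helper that only assembles the argument list and does not run rsync. Any default flags the real script may add are not visible in this preview and are deliberately left out.

```python
import shlex


def build_rsync_command(args):
    """Hypothetical helper mirroring: [additional rsync options] <source> <destination>."""
    # Same guard as the script: fewer than two arguments cannot name both ends
    if len(args) < 2:
        raise ValueError(
            "usage: rsync_script.sh [additional rsync options] <source> <destination>"
        )
    # Everything before the last two arguments is passed through as rsync options
    *options, source, destination = args
    return ["rsync", *options, source, destination]


cmd = build_rsync_command(
    ["--bwlimit=1000", "--info=progress2", "/path/to/source", "user@destination:/path/to/destination"]
)
print(shlex.join(cmd))  # shell-quoted form of the command (Python 3.8+)
```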
# 1. Reconstruction loss: x0 ≈ x0_reconstructed
#    (mean_squared_error returns one value per sample, so average to a scalar)
reconstruction_loss = tf.reduce_mean(
    tf.keras.losses.mean_squared_error(x0, x0_reconstructed))
model.add_loss(reconstruction_loss)
# 2. Future state prediction loss: x1 ≈ x1_pred
state_prediction_loss = tf.reduce_mean(
    tf.keras.losses.mean_squared_error(x1, x1_pred))
model.add_loss(state_prediction_loss)
# 3. Linear dynamics loss: code1 ≈ code1_pred = K * code0
linear_dynamics_loss = tf.reduce_mean(
    tf.keras.losses.mean_squared_error(code1, code1_pred))
model.add_loss(linear_dynamics_loss)
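To make the three terms concrete without TensorFlow, here is a NumPy sketch of the same objectives. `latent_linear_losses` is a hypothetical name. `encode` and `decode` stand in for the encoder and decoder sub-models, and `K` is an explicit matrix for the linear latent dynamics. Losses registered with `add_loss` are summed into the training objective, so the sketch also adds them up. The toy check uses identity maps and a rotation as the true dynamics, so every term comes out zero.

```python
import numpy as np


def mse(a, b):
    # Mean squared error averaged over every element, returned as a Python float
    return float(np.mean((a - b) ** 2))


def latent_linear_losses(x0, x1, encode, decode, K):
    """x0, x1: (batch, input_dim) arrays for x_k and x_{k+1}; K: (latent_dim, latent_dim)."""
    code0 = encode(x0)
    code1 = encode(x1)
    code1_pred = code0 @ K.T          # row-vector form of code1 = K * code0
    x0_reconstructed = decode(code0)
    x1_pred = decode(code1_pred)      # decode the predicted next code
    return {
        "reconstruction": mse(x0, x0_reconstructed),
        "state_prediction": mse(x1, x1_pred),
        "linear_dynamics": mse(code1, code1_pred),
    }


# Toy check: identity encoder/decoder and a rotation as the true dynamics
theta = 0.1
K = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
x0 = np.random.default_rng(0).normal(size=(8, 2))
x1 = x0 @ K.T
identity = lambda z: z
losses = latent_linear_losses(x0, x1, identity, identity, K)
total_loss = sum(losses.values())     # unweighted sum of the three terms
print(losses, total_loss)
```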
# x_train, y_train = ...
model.compile(loss=tf.keras.losses.CategoricalCrossentropy())
model.fit(x=x_train, y=y_train) |
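`CategoricalCrossentropy` expects `y_train` to be one-hot encoded. For integer class labels, `SparseCategoricalCrossentropy` is the usual choice. The NumPy sketch below shows what the loss computes for probability outputs: the negative log of the probability assigned to the true class, averaged over the batch. The function name and the clipping constant are assumptions for illustration. It also assumes each row of `y_pred` already sums to 1.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # y_true: one-hot labels, shape (batch, num_classes)
    # y_pred: predicted probabilities, shape (batch, num_classes), rows summing to 1
    y_pred = np.clip(y_pred, eps, 1.0 - eps)          # avoid log(0)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=-1)
    return float(np.mean(per_sample))                 # average over the batch


y_true = np.array([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 0.0]])
y_pred = np.array([[0.1, 0.8, 0.1],
                   [0.6, 0.3, 0.1]])
loss = categorical_crossentropy(y_true, y_pred)       # -(log 0.8 + log 0.6) / 2
print(loss)
```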
# Add onto our existing x0/code0 computation graph
# x1 is the input x_{k+1}
x1 = tf.keras.Input(shape=(INPUT_DIM,), name='x1')
# We can re-use the same sub-models (encoder, decoder, linear_dynamics)
# so that both time steps share the same weights
code1 = encoder(x1)
# Predicted next code from the current one: code1_pred = K * code0
code1_pred = linear_dynamics(code0)
x1_pred = decoder(code1_pred)
model = tf.keras.Model(
    inputs={'x0': x0, 'x1': x1},
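Weight sharing here relies on calling the same sub-model object on both inputs. In the Keras functional API, calling a layer or model instance again reuses its existing weights instead of creating new ones. The NumPy sketch below illustrates the same principle with a hypothetical `Dense` class: both codes are computed from one weight matrix, so changing that matrix changes both outputs.

```python
import numpy as np


class Dense:
    """Hypothetical minimal dense layer, y = x @ W + b, owning a single weight matrix."""

    def __init__(self, in_dim, out_dim, rng):
        self.W = rng.normal(size=(in_dim, out_dim))
        self.b = np.zeros(out_dim)

    def __call__(self, x):
        return x @ self.W + self.b


rng = np.random.default_rng(0)
encoder = Dense(2, 2, rng)      # one set of weights

x0 = rng.normal(size=(4, 2))    # states x_k
x1 = rng.normal(size=(4, 2))    # states x_{k+1}

# The same object encodes both time steps, so both calls read the same W and b
code0 = encoder(x0)
code1 = encoder(x1)

# Doubling the shared weights doubles both codes (b is zero here)
encoder.W = encoder.W * 2.0
```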
# Connect sub-models
x = tf.keras.Input(shape=(INPUT_DIM,))
code = encoder(x)
x_reconstructed = decoder(code)
next_code_pred = linear_dynamics(code)
model = tf.keras.Model(
    inputs=x,
    outputs={'x0_reconstructed': x_reconstructed,
             'y1': next_code_pred}
)
import tensorflow as tf

# Hyperparameters
INPUT_DIM = 2  # Same as decoder output dimension
HIDDEN_DIM = 30
LATENT_DIM = 2

# Encodes input to low-dimensional code
encoder = tf.keras.Sequential(
    [
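The encoder's layer list is cut off in this preview. The NumPy sketch below shows one plausible set of shapes implied by the hyperparameters: encoder `INPUT_DIM -> HIDDEN_DIM -> LATENT_DIM`, decoder `LATENT_DIM -> HIDDEN_DIM -> INPUT_DIM`, and linear dynamics as a single `LATENT_DIM x LATENT_DIM` matrix with no bias. The two-layer depth, the `tanh` activation, the initialization, and the helper names `init_mlp` and `mlp` are all assumptions, not the original layers. The forward pass produces the same two outputs as the single-input model (`x0_reconstructed` and `y1`).

```python
import numpy as np

INPUT_DIM = 2
HIDDEN_DIM = 30
LATENT_DIM = 2

rng = np.random.default_rng(0)


def init_mlp(in_dim, hidden_dim, out_dim):
    # Two dense layers (assumed depth): in_dim -> hidden_dim -> out_dim
    return {
        "W1": rng.normal(scale=0.1, size=(in_dim, hidden_dim)),
        "b1": np.zeros(hidden_dim),
        "W2": rng.normal(scale=0.1, size=(hidden_dim, out_dim)),
        "b2": np.zeros(out_dim),
    }


def mlp(params, x):
    # tanh hidden activation is an assumption; the output layer is linear
    h = np.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


encoder_params = init_mlp(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
decoder_params = init_mlp(LATENT_DIM, HIDDEN_DIM, INPUT_DIM)
# Linear dynamics: code1 = K * code0, a plain matrix with no bias or activation
K = rng.normal(scale=0.1, size=(LATENT_DIM, LATENT_DIM))

x = rng.normal(size=(5, INPUT_DIM))          # a batch of 5 states
code = mlp(encoder_params, x)
outputs = {
    "x0_reconstructed": mlp(decoder_params, code),
    "y1": code @ K.T,                        # row-vector form of K * code
}
```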
pr <- function(...) {
  # prints variable names and values
  dots <- substitute(list(...))[-1]  # remove 1st element, "list"
  print(sapply(dots, deparse))  # names
  print(paste(list(...)))  # values
}

# demo
a <- 1
b <- c(3,5,2) |
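For comparison, Python 3.8+ has a built-in way to print an expression's name alongside its value. The `=` specifier in an f-string prints the expression text followed by its `repr`. This covers the same need as `pr` without capturing unevaluated arguments the way `substitute` does.

```python
a = 1
b = [3, 5, 2]

# "=" inside an f-string prints the expression text, then its repr
print(f"{a=}")        # a=1
print(f"{b=}")        # b=[3, 5, 2]
print(f"{a=}, {b=}")  # a=1, b=[3, 5, 2]
```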