Skip to content

Instantly share code, notes, and snippets.

@sile
Last active April 8, 2020 08:55
Show Gist options
  • Save sile/1158ba37ff5b8c290f8953acebffed80 to your computer and use it in GitHub Desktop.
Save sile/1158ba37ff5b8c290f8953acebffed80 to your computer and use it in GitHub Desktop.
Example: gin-config with optuna
train.batch_size = 10
train.learning_rate = 0.1
import re
import optuna
import subprocess
import tempfile
def objective(trial):
    """Run one training subprocess with trial-suggested parameters.

    The suggested values are substituted into the text of ``config.gin``,
    the result is written to a temporary file, and ``train.py`` is executed
    with ``--config-path`` pointing at that file.

    Args:
        trial: Object providing ``suggest_int``, ``suggest_loguniform`` and a
            ``params`` mapping of the suggested values (an optuna trial when
            called through ``study.optimize``).

    Returns:
        The standard output of ``train.py`` parsed as a float.

    Raises:
        subprocess.CalledProcessError: If ``train.py`` exits with a non-zero
            status.
        ValueError: If the output of ``train.py`` cannot be parsed as a float.
    """
    # Suggest parameters; the values are read back through ``trial.params``.
    trial.suggest_int("train.batch_size", 4, 100)
    trial.suggest_loguniform("train.learning_rate", 0.0001, 1.0)

    # Replace config entries with the suggested parameters.
    with open("config.gin", encoding="utf-8") as config_file:
        config = config_file.read()
    for name, value in trial.params.items():
        # Escape the name so "." only matches a literal dot; otherwise
        # look-alike keys such as "trainXbatch_size" would be rewritten too.
        pattern = "(?<=" + re.escape(name) + ") *=.*"
        replacement = "=" + str(value)
        # A callable replacement inserts the text verbatim, with no
        # backslash/group-reference interpretation.
        config = re.sub(pattern, lambda _match: replacement, config)

    # Create a temporary config file. It is deleted when the ``with`` block
    # exits, so the training run has to happen inside the block.
    with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp:
        temp.write(config)
        # Flush so the child process sees the full content on disk.
        temp.flush()

        # Run train script with the temporary config; ``check=True`` turns a
        # failed run into an error instead of parsing partial output.
        result = subprocess.run(
            ["python3", "train.py", "--config-path", temp.name],
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )

    # Parse the script output to get the objective value.
    return float(result.stdout)
# Create a study with optuna's default settings, then evaluate the objective
# for ten trials.
study = optuna.create_study()
study.optimize(func=objective, n_trials=10)
import argparse
import gin
@gin.configurable
def train(batch_size, learning_rate):
    """Print the placeholder score ``batch_size * learning_rate``.

    Neither parameter has a default; this script calls ``train()`` without
    arguments after parsing a gin config, so the values are expected to come
    from the ``train.batch_size`` / ``train.learning_rate`` bindings.

    Args:
        batch_size: First factor of the printed product.
        learning_rate: Second factor of the printed product.
    """
    # TODO: Replace with a real training code.
    score = batch_size * learning_rate
    # The bare value is the only output of the function.
    print(f"{score}")
# Command line: a single optional flag naming the gin file to load.
parser = argparse.ArgumentParser()
parser.add_argument("--config-path", dest="config_path", default="config.gin")
args = parser.parse_args()

# Load the gin bindings first, then call train() with no explicit arguments.
gin.parse_config_file(config_file=args.config_path)
train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment