Created
July 4, 2015 03:09
-
-
Save smackesey/633053c3408e266f2dd0 to your computer and use it in GitHub Desktop.
spearmint_shunt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import copy | |
import json | |
import numpy as np | |
import os | |
import os.path as osp | |
import sys | |
from spearmint.tasks.task_group import TaskGroup | |
from spearmint.choosers.default_chooser import DefaultChooser | |
# Spearmint line num refs for: | |
# rev e34b19 | |
################################### | |
########### SHUNT | |
################################### | |
class Experiment(object):
    """Whetlab-style experiment API implemented on top of Spearmint.

    Every experiment lives in a single JSON file at ``path``, keyed by its
    ``name``.  Each entry has the shape::

        {'jobs': [...], 'hypers': {...}, 'parameters': {...}, 'outcome': <name>}

    where each job is the parameter dict passed to update() plus one extra key
    (the outcome name) holding the observed result.  The outcome is treated as
    something to maximize (see best() and the negation in suggest()).
    """

    def __init__(self, name=None, description=None, parameters=None,
                 outcome=None, path=None):
        """Open (creating if necessary) the JSON database and load experiment ``name``.

        Args:
            name: key of this experiment inside the JSON database.
            description: accepted for Whetlab API compatibility; not used.
            parameters: Whetlab-style parameter spec, e.g.
                ``{'C': {'min': 0.01, 'max': 1000.0, 'type': 'float'}}``.
                Only used when ``name`` is not already in the database.
            outcome: dict whose ``'name'`` entry names the value being
                maximized.  Only read when ``name`` is not already in the
                database.
            path: filesystem path of the JSON database file.

        A newly created experiment entry exists only in memory until the
        first call to update() writes the database back to disk.
        """
        self.path = path
        # create an empty database if it doesn't exist yet
        if not osp.exists(self.path):
            with open(self.path, 'w') as f:
                f.write(json.dumps({}))
        # read the whole database (it may hold several experiments)
        with open(self.path, 'r') as f:
            self.db = json.load(f)
        # `in` rather than dict.has_key(), which no longer exists in Python 3
        if name not in self.db:
            self.db[name] = {'jobs': [], 'hypers': {},
                             'parameters': parameters, 'outcome': outcome['name']}
        # self.store aliases the entry inside self.db, so mutations made through
        # it are what update() serializes
        self.store = self.db[name]

    def suggest(self):
        """Fit Spearmint to all recorded jobs and propose the next point to evaluate.

        Returns:
            dict mapping each parameter name to one suggested value.  NumPy
            scalars are converted to built-in Python numbers so the dict can
            be passed straight to update() and serialized to JSON.

        Side effects: sets ``self.chooser`` and replaces the in-memory
        ``hypers``; the new hypers reach disk only on the next update().
        """
        sparams = self.spearmintify_parameters(self.store['parameters'])
        self.chooser = DefaultChooser({})
        # see spearmint/main.py line 231-232; this is the default value if no tasks are provided in config
        tasks = {'main': {'type': 'OBJECTIVE', 'likelihood': 'GAUSSIAN'}}
        task_group = TaskGroup(tasks, sparams)
        sjobs = [self.spearmintify_job(job) for job in self.store['jobs']]
        task_group.inputs = np.array([task_group.vectorify(job['params']) for job in sjobs])
        # the stored outcome is maximized (see best()); Spearmint minimizes, so negate it
        task_group.values = {'main': np.array([-job[self.store['outcome']] for job in sjobs])}
        task_group.add_nan_task_if_nans()
        shypers = self.spearmintify_hypers(self.store['hypers'])
        shypers = self.chooser.fit(task_group, shypers, tasks)
        self.store['hypers'] = self.dbify_hypers(shypers)
        suggestion = task_group.paramify(self.chooser.suggest())
        # each parameter is declared with size 1, so take the single value
        return {k: self._to_native(v['values'][0]) for k, v in suggestion.items()}

    def update(self, job, obj):
        """Record that evaluating ``job`` produced ``obj`` and persist the database.

        Args:
            job: dict of parameter name -> value that was evaluated.  It is
                shallow-copied, so the caller's dict is not modified.
            obj: observed outcome value; stored under the experiment's
                outcome name.

        Raises:
            TypeError: if the database contains values ``json`` cannot
                serialize.  The file on disk is left untouched in that case.
        """
        # this is a representation of the job that includes the keys we need for spearmint
        store_job = copy.copy(job)
        store_job[self.store['outcome']] = obj
        self.store['jobs'].append(store_job)
        # Serialize BEFORE opening: open(..., 'w') truncates immediately, so a
        # serialization error inside the with-block would wipe the database.
        serialized = json.dumps(self.db)
        with open(self.path, 'w') as f:
            f.write(serialized)

    def best(self):
        """Return the recorded job with the highest outcome value.

        Raises:
            ValueError: if no jobs have been recorded yet (``max`` of an
                empty sequence).
        """
        return max(self.store['jobs'], key=lambda j: j[self.store['outcome']])

    ########### TRANSLATE BETWEEN SPEARMINT AND WHETLAB STYLE FORMATS

    @staticmethod
    def _to_native(value):
        """Return ``value`` as a built-in Python scalar if it is a NumPy scalar."""
        return value.item() if isinstance(value, np.generic) else value

    def spearmintify_parameters(self, params):
        """Return a deep copy of Whetlab-style ``params`` in Spearmint's format.

        Each variable gets ``size`` 1 unless it already has one, and its type
        name is translated by spearmintify_type().  ``params`` is not modified.
        """
        params = copy.deepcopy(params)
        for spec in params.values():
            spec.setdefault('size', 1)
            spec['type'] = self.spearmintify_type(spec['type'])
        return params

    def spearmintify_type(self, tname):
        """Map a Whetlab type name to Spearmint's: 'integer' -> 'int', others lowercased."""
        if tname.lower() == 'integer':
            return 'int'
        else:
            return tname.lower()

    def spearmintify_job(self, job):
        """Convert a stored job into the structure suggest() feeds to Spearmint.

        Returns:
            ``{'params': {name: {'type': ..., 'values': ndarray}}, 'values': {},
            <outcome name>: <outcome value>}``.  Keys of ``job`` that are
            neither declared parameters nor the outcome are dropped.
        """
        sjob = {'params': {}, 'values': {}}
        declared = self.store['parameters']
        for k, v in job.items():
            if k in declared:
                sjob['params'][k] = {
                    'type': self.spearmintify_type(declared[k]['type']),
                    'values': np.array(v)}
            elif k == self.store['outcome']:
                sjob[k] = v
        return sjob

    def spearmintify_hypers(self, hypers):
        """Turn the JSON lists of the array-valued 'main' hypers back into NumPy arrays.

        Returns ``hypers`` itself when it has no 'main' task, otherwise a
        deep copy with 'beta_alpha', 'beta_beta' and 'ls' as ndarrays.
        """
        if 'main' not in hypers:
            return hypers
        shypers = copy.deepcopy(hypers)
        for k in ['beta_alpha', 'beta_beta', 'ls']:
            if k in hypers['main']['hypers']:
                shypers['main']['hypers'][k] = np.array(hypers['main']['hypers'][k])
        return shypers

    def dbify_hypers(self, hypers):
        """Inverse of spearmintify_hypers(): make the array-valued 'main' hypers JSON-friendly.

        ``ndarray.tolist()`` yields plain (possibly nested) Python lists and
        floats, so multi-dimensional and 0-d arrays serialize correctly,
        unlike ``list(array)``.
        """
        if 'main' not in hypers:
            return hypers
        dhypers = copy.deepcopy(hypers)
        for k in ['beta_alpha', 'beta_beta', 'ls']:
            if k in hypers['main']['hypers']:
                dhypers['main']['hypers'][k] = np.asarray(hypers['main']['hypers'][k]).tolist()
        return dhypers
################################### | |
########### TEST | |
################################### | |
if __name__ == '__main__':
    # Demo: tune a polynomial-kernel SVM with the Experiment shunt.
    # Guarded so that importing Experiment from this module does not download
    # data, fit models, or write db.json; the sklearn imports live here so the
    # shunt itself does not require scikit-learn.
    from sklearn import svm
    # NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed
    # in 0.22 (mldata.org is offline); newer versions need another data source,
    # e.g. sklearn.datasets.fetch_openml.
    from sklearn.datasets import fetch_mldata

    data_set = fetch_mldata('yahoo-web-directory-topics')
    # first 1000 examples for training, the remainder for validation
    train_set = (data_set['data'][:1000], data_set['target'][:1000])
    validation_set = (data_set['data'][1000:], data_set['target'][1000:])

    # modify as desired
    db_path = osp.join(os.getcwd(), 'db.json')
    parameters = {'C': {'min': 0.01, 'max': 1000.0, 'type': 'float'},
                  'degree': {'min': 1, 'max': 5, 'type': 'integer'}}
    outcome = {'name': 'Classification accuracy'}
    scientist = Experiment(name="Web page classifier",
                           description="Training a polynomial kernel SVM to classify web pages",
                           parameters=parameters,
                           outcome=outcome,
                           path=db_path)

    n_iterations = 19
    for _ in range(n_iterations):
        job = scientist.suggest()
        learner = svm.SVC(kernel='poly', **job)
        learner.fit(*train_set)
        accuracy = learner.score(*validation_set)
        scientist.update(job, accuracy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment