Wrap a PyTorch module so it can be optimized with scipy's optimize.minimize: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
Besides the two Python files below, the gist contains a .gitignore with the single line *.pyc, and a one-line file containing only # (presumably the package __init__.py for pytorchobjective).
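In short, the usage pattern is: wrap any nn.Module whose forward() takes no arguments and returns a scalar loss, then hand the wrapper's fun, jac, and x0 to scipy. A minimal sketch (my_loss_module is a placeholder, not part of the gist):

from scipy import optimize
from pytorchobjective.obj_torch import PyTorchObjective

obj = PyTorchObjective(my_loss_module)  # my_loss_module: any nn.Module returning a scalar loss
result = optimize.minimize(obj.fun, obj.x0, jac=obj.jac, method='BFGS')
# result.x holds the flattened optimal parameters;
# obj.unpack_parameters(result.x) maps them back to named tensors.

The full example script and the wrapper module follow.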
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from scipy import optimize
from pytorchobjective.obj_torch import PyTorchObjective
from tqdm import tqdm

if __name__ == '__main__':
    # whatever this initialises to is our "true" W
    linear = nn.Linear(32, 32)
    linear = linear.eval()

    # input X
    N = 10000
    X = torch.Tensor(N, 32)
    X.uniform_(0., 1.)  # fill with uniform
    eps = torch.Tensor(N, 32)
    eps.normal_(0., 1e-4)

    # output Y
    with torch.no_grad():
        Y = linear(X)  # + eps

    # make module executing the experiment
    class Objective(nn.Module):
        def __init__(self):
            super(Objective, self).__init__()
            self.linear = nn.Linear(32, 32)
            self.linear = self.linear.train()
            self.X, self.Y = X, Y

        def forward(self):
            output = self.linear(self.X)
            return F.mse_loss(output, self.Y).mean()

    objective = Objective()
    maxiter = 100
    with tqdm(total=maxiter) as pbar:
        def verbose(xk):
            pbar.update(1)

        # try to optimize that function with scipy
        obj = PyTorchObjective(objective)
        xL = optimize.minimize(obj.fun, obj.x0, method='BFGS', jac=obj.jac,
                               callback=verbose,
                               options={'gtol': 1e-6, 'disp': True,
                                        'maxiter': maxiter})
        # xL = optimize.minimize(obj.fun, obj.x0, method='CG', jac=obj.jac)  # , options={'gtol': 1e-2}
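Since the script fits a fresh nn.Linear(32, 32) to data generated by the "true" linear layer, a natural sanity check (a sketch, not part of the gist) is to load the optimum back into the module and compare the recovered weights with the true ones:

# Hypothetical sanity check: reload the optimum and compare the fitted
# weights against the "true" linear layer from the script above.
obj.fun(xL.x)  # caching xL.x loads it into objective.linear via load_state_dict
with torch.no_grad():
    w_err = (objective.linear.weight - linear.weight).abs().max()
    b_err = (objective.linear.bias - linear.bias).abs().max()
print('max |W_hat - W| = %.3g, max |b_hat - b| = %.3g'
      % (w_err.item(), b_err.item()))

The wrapper module itself (obj_torch.py, as imported above) follows.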
""" | |
From https://gist.github.com/gngdb/a9f912df362a85b37c730154ef3c294b | |
2021-01 YK added parameters() | |
""" | |
import torch | |
from scipy import optimize | |
import torch.nn.functional as F | |
import math | |
import numpy as np | |
from functools import reduce | |
from collections import OrderedDict | |
class PyTorchObjective(object): | |
"""PyTorch objective function, wrapped to be called by scipy.optimize.""" | |
def __init__(self, obj_module, separate_loss_for_jac=False): | |
""" | |
:param obj_module: | |
:param separate_loss_for_jac: if True, obj_module.forward() returns | |
two separate losses, first for gradient computation, second for | |
the loss itself. Used, e.g., for REINFORCE. | |
obj_for_jac, obj = obj_module() | |
""" | |
self.f = obj_module # some pytorch module, that produces a scalar loss | |
# make an x0 from the parameters in this module | |
self.x0 = self.parameters() | |
self.separate_loss_for_jac = separate_loss_for_jac | |
def parameters(self) -> np.ndarray: | |
parameters = OrderedDict(self.f.named_parameters()) | |
self.param_shapes = {n: parameters[n].size() for n in parameters} | |
# ravel and concatenate all parameters to make x0 | |
return np.concatenate([parameters[n].data.numpy().ravel() | |
for n in parameters]) | |
def unpack_parameters(self, x): | |
"""optimize.minimize will supply 1D array, chop it up for each parameter.""" | |
i = 0 | |
named_parameters = OrderedDict() | |
for n in self.param_shapes: | |
param_len = reduce(lambda x,y: x*y, self.param_shapes[n]) | |
# slice out a section of this length | |
param = x[i:i+param_len] | |
# reshape according to this size, and cast to torch | |
param = param.reshape(*self.param_shapes[n]) | |
named_parameters[n] = torch.from_numpy(param) | |
# update index | |
i += param_len | |
return named_parameters | |
def pack_grads(self): | |
"""pack all the gradients from the parameters in the module into a | |
numpy array.""" | |
grads = [] | |
for i, p in enumerate(self.f.parameters()): | |
# print(i) # CHECKED | |
grad = p.grad.data.numpy() | |
grads.append(grad.ravel()) | |
return np.concatenate(grads) | |
def is_new(self, x): | |
# if this is the first thing we've seen | |
if not hasattr(self, 'cached_x'): | |
return True | |
else: | |
# compare x to cached_x to determine if we've been given a new input | |
x, self.cached_x = np.array(x), np.array(self.cached_x) | |
error = np.abs(x - self.cached_x) | |
return error.max() > 1e-8 | |
def cache(self, x): | |
# unpack x and load into module | |
state_dict = self.unpack_parameters(x) | |
self.f.load_state_dict(state_dict) | |
# store the raw array as well | |
self.cached_x = x | |
# zero the gradient | |
self.f.zero_grad() | |
# use it to calculate the objective | |
if self.separate_loss_for_jac: | |
obj_jac, obj = self.f() | |
# backprop the objective | |
obj_jac.backward() | |
self.cached_f = obj.item() | |
else: | |
obj = self.f() | |
# backprop the objective | |
obj.backward() | |
self.cached_f = obj.item() | |
self.cached_jac = self.pack_grads() | |
def fun(self, x): | |
if self.is_new(x): | |
self.cache(x) | |
return self.cached_f | |
def jac(self, x): | |
if self.is_new(x): | |
self.cache(x) | |
return self.cached_jac |
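The separate_loss_for_jac flag lets forward() return two losses: a surrogate whose gradient is backpropagated, and the loss whose value is reported to scipy. A hypothetical, self-contained sketch of a REINFORCE-style objective (none of these names appear in the gist):

import torch
import torch.nn as nn
from scipy import optimize
# PyTorchObjective as defined above

class ReinforceObjective(nn.Module):
    """Toy REINFORCE objective: learn a 3-way categorical policy."""
    def __init__(self):
        super().__init__()
        self.logit = nn.Parameter(torch.zeros(3))

    def forward(self):
        dist = torch.distributions.Categorical(logits=self.logit)
        a = dist.sample((1024,))
        reward = -(a.float() - 2.) ** 2   # best action is a == 2
        # surrogate: differentiable w.r.t. the policy through log-probs
        surrogate = -(dist.log_prob(a) * reward).mean()
        true_loss = -reward.mean()        # reported (non-differentiable) loss
        return surrogate, true_loss

obj = PyTorchObjective(ReinforceObjective(), separate_loss_for_jac=True)
res = optimize.minimize(obj.fun, obj.x0, jac=obj.jac, method='BFGS',
                        options={'maxiter': 50})

This only illustrates the two-loss calling convention; because the sampled objective is stochastic, a quasi-Newton method like BFGS will converge roughly at best. Note also that fun and jac share one cache keyed on x (via is_new/cache), so scipy's separate calls for the value and the gradient at the same point trigger only a single forward/backward pass.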