Last active
July 10, 2019 15:28
-
-
Save jdiez17/6927c3ece844ebbe881e1294b38b6d7e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, OrderedDict | |
from scipy.integrate import solve_ivp | |
import networkx as nx | |
import numpy as np | |
import pdb | |
class Node:
    """Base class for a model built from ``Value`` descriptors.

    Subclasses declare their inputs, outputs and states as class attributes
    (``Input()``, ``Output()``, ``State()``) and implement :meth:`solve`.
    """

    def __init__(self, *args, **kwargs):
        """Set every declared value to None, then apply keyword overrides.

        ``*args`` is accepted but ignored.

        Raises:
            TypeError: if a keyword argument does not name a declared value.
        """
        # Both dicts must exist before the first setattr: Value.__set__ writes into
        # _values and then forwards the new value along _connections.
        self._values = {}
        self._connections = defaultdict(list)

        # Collect descriptors from the whole MRO, base classes first, so values
        # declared on a parent Node subclass are initialised too; a more-derived
        # attribute of the same name replaces the inherited one.
        attrs = {}
        for klass in reversed(type(self).__mro__):
            attrs.update(vars(klass))

        # Initialize all values to None
        for name, attr in attrs.items():
            if isinstance(attr, Value):
                setattr(self, name, None)

        # Override any values if they are passed as kwargs
        for name, value in kwargs.items():
            if name not in self._values:
                raise TypeError("{}() got an unexpected keyword argument {!r}".format(
                    type(self).__name__, name))
            setattr(self, name, value)

    def solve(self, *args, **kwargs):
        """Compute this node's update for the current time; subclasses must override.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Function solve() not implemented for Node class {}".format(self.__class__.__name__))
class Value:
    """Data descriptor that stores a per-instance value in the owner's ``_values`` dict.

    Assigning a value also forwards it to every ``(target, attribute)`` pair
    listed in the owner's ``_connections[name]``.
    """

    def __init__(self, initval=0):
        # Kept on the descriptor; this class itself never reads it.
        self.initval = initval

    def __set_name__(self, owner, name):
        # Called at class-creation time; the attribute name is the storage key.
        self.name = name

    def __get__(self, inst, objtype=None):
        if inst is None:
            # Class-level access (e.g. ``Orbit.r``): return the descriptor itself,
            # as the descriptor protocol expects, instead of touching None.__dict__.
            return self
        try:
            return inst.__dict__['_values'][self.name]
        except KeyError:
            # Report an unset value as AttributeError so hasattr()/getattr(default) work.
            raise AttributeError("{!r} object has no value {!r}".format(
                type(inst).__name__, self.name)) from None

    def __set__(self, inst, val):
        inst.__dict__['_values'][self.name] = val
        # Push the new value to every attribute connected to this one.
        for target, target_attr in inst._connections[self.name]:
            setattr(target, target_attr, val)


class Input(Value):
    """A value a node consumes, usually fed through a connection from another node."""


class Output(Value):
    """A value a node produces and may forward to connected nodes."""


class State(Output):
    """An output the solver integrates over time from the derivatives returned by solve()."""
class NodeGraph:
    """Directed graph of connections between nodes, traversed from a root node.

    An edge points from the node owning a value to the node receiving it and is
    labelled with the connected attribute name(s).
    """

    def __init__(self, root_node):
        self.G = nx.DiGraph()
        # Register the root right away so traversal works before (or without) any
        # connect() call; bfs_edges raises for a source missing from the graph.
        self.G.add_node(root_node)
        self._root_node = root_node

    def get_execution_order(self):
        """Return the graph's nodes in evaluation order.

        The root comes first, then nodes in breadth-first discovery order from the
        root, then any nodes not reachable from the root, in insertion order.
        NOTE: this is not a topological sort, so a node may precede one feeding it.
        """
        root = self._root_node
        order = []
        if root in self.G:
            order.append(root)
            # bfs_edges yields each newly discovered node exactly once as an edge head.
            order.extend(head for _, head in nx.bfs_edges(self.G, root))
        reached = set(order)
        order.extend(node for node in self.G if node not in reached)
        return order

    def connect(self, obj, prop, target):
        """Forward attribute ``prop`` of ``obj`` to the same-named attribute of ``target``."""
        labels = []
        if self.G.has_edge(obj, target):
            existing = self.G[obj][target].get('label', '')
            labels = [label for label in existing.split(', ') if label]
        if prop not in labels:
            labels.append(prop)
        # A DiGraph holds one edge per (obj, target) pair, so accumulate the labels
        # rather than overwrite them when several attributes link the same nodes.
        self.G.add_edge(obj, target, label=', '.join(labels))
        if (target, prop) not in obj._connections[prop]:
            obj._connections[prop].append((target, prop))

    def get_human_readable_graph(self):
        """Return a left-to-right pydot graph whose nodes are labelled by class name."""
        # Nodes of the same class get a numeric suffix; otherwise relabel_nodes
        # would merge distinct nodes that share a class name into one.
        totals = defaultdict(int)
        for node in self.G:
            totals[type(node).__name__] += 1
        used = defaultdict(int)
        mapping = {}
        for node in self.G:
            name = type(node).__name__
            if totals[name] > 1:
                mapping[node] = "{}_{}".format(name, used[name])
                used[name] += 1
            else:
                mapping[node] = name
        # TODO use record-based nodes https://www.graphviz.org/doc/info/shapes.html#record
        g = nx.relabel_nodes(self.G, mapping, copy=True)
        p = nx.nx_pydot.to_pydot(g)
        p.set_rankdir("LR")
        return p
class Solver:
    """Integrate the State values of a set of connected nodes with scipy's solve_ivp.

    All State values are packed into one flat vector. Models appear in execution
    order, and within a model its states appear in declaration order.
    """

    def __init__(self, *nodes):
        """Register ``nodes``; the first one is the root of the execution graph."""
        self._execution_order = []
        self._state_lengths = []
        self.state_map = OrderedDict()  # model -> names of its State attributes
        self.node_graph = NodeGraph(nodes[0])  # First node is considered to be the root node
        for node in nodes:
            self.add(node)

    def connect(self, *args, **kwargs):
        """Connect two nodes; arguments are forwarded to the node graph's connect()."""
        return self.node_graph.connect(*args, **kwargs)

    def add(self, model):
        """Register ``model`` and record its State attribute names (no-op if already added)."""
        if model in self.state_map:
            return
        # Walk the MRO base-first so States inherited from a parent Node subclass
        # are found; a more-derived attribute of the same name overrides them.
        attrs = {}
        for klass in reversed(type(model).__mro__):
            attrs.update(vars(klass))
        self.state_map[model] = [name for name, attr in attrs.items() if isinstance(attr, State)]
        # Make the model a graph node even if it is never connected.
        self.node_graph.G.add_node(model)

    def _rhs(self, t, x):
        """Right-hand side for solve_ivp: return d(state)/dt for the flat vector ``x``.

        Raises:
            TypeError: if a model that has State values returns None from solve().
            ValueError: if the stacked derivatives do not match the length of ``x``.
        """
        # Unpack x onto the model attributes using the layout recorded by solve().
        # For Node models the State descriptors also forward values to connected inputs.
        offset = 0
        lengths = iter(self._state_lengths)
        for model in self._execution_order:
            for state in self.state_map[model]:
                length = next(lengths)
                setattr(model, state, x[offset:offset + length])
                offset += length

        # Run every model; results of stateless models are ignored.
        derivatives = []
        for model in self._execution_order:
            results = model.solve(t)
            if not self.state_map[model]:
                continue
            if results is None:
                raise TypeError("{}.solve() returned None but the model has State values".format(
                    type(model).__name__))
            if isinstance(results, (list, tuple)):
                derivatives.extend(np.ravel(item) for item in results)
            else:
                # A bare scalar/array is taken as the derivative of all the model's states.
                derivatives.append(np.ravel(results))

        dxdt = np.hstack(derivatives)
        if dxdt.shape != np.shape(x):
            raise ValueError("Models returned {} derivative entries for a state vector of length {}".format(
                dxdt.size, np.size(x)))
        return dxdt

    def solve(self, start, end, **ivp_kwargs):
        """Integrate every State from ``start`` to ``end`` and return solve_ivp's result.

        Extra keyword arguments are passed to scipy.integrate.solve_ivp; ``rtol``
        defaults to 1e-9.

        Raises:
            ValueError: if a State has no initial value (None) or no model declares a State.
        """
        order = list(self.node_graph.get_execution_order())
        # Nodes reachable only through connect() were never add()ed; register them.
        for model in order:
            self.add(model)
        # Models added to the solver but outside the graph traversal would otherwise be skipped.
        order.extend([model for model in self.state_map if model not in order])
        self._execution_order = order

        # Gather the current states into one flat initial vector.
        self._state_lengths = []
        initial = []
        for model in order:
            for state in self.state_map[model]:
                value = getattr(model, state)
                if value is None:
                    raise ValueError("State {!r} of {} has no initial value".format(
                        state, type(model).__name__))
                try:
                    # Sequences occupy len() entries; scalars take a single entry.
                    length = len(value)
                except TypeError:
                    length = 1
                self._state_lengths.append(length)
                initial.append(value)
        if not initial:
            raise ValueError("No State values to integrate")

        ivp_kwargs.setdefault('rtol', 1e-9)
        return solve_ivp(self._rhs, (start, end), np.hstack(initial), **ivp_kwargs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from lib import Node, Input, Output, State, Solver | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import numpy as np | |
from astropy.constants import GM_earth | |
class Orbit(Node):
    """Point-mass orbit about Earth under Earth's gravitational parameter only.

    States ``r`` (position) and ``v`` (velocity) are vectors; units presumably
    follow GM_earth.value (SI) — confirm against callers.
    """

    r = State()
    v = State()

    def solve(self, t):
        """Return ``[dr/dt, dv/dt]`` = ``[v, -GM * r / |r|**3]``; ``t`` is not used."""
        mu = GM_earth.value
        r_cubed = np.linalg.norm(self.r) ** 3
        return [self.v, -mu * self.r / r_cubed]
class MagneticField(Node):
    """Placeholder magnetic-field model: takes position ``r`` as input, computes nothing yet."""

    r = Input()

    def solve(self, t):
        # Not implemented. An earlier draft converted the geocentric position to
        # latitude/longitude/height via astropy's EarthLocation.from_geocentric(..., unit=u.m).
        return None
class SunModel(Node):
    """Placeholder sun model declaring outputs ``S`` and ``F``; solve() does nothing yet."""

    S = Output()
    F = Output()

    def solve(self, t):
        # Outputs are never assigned here yet, so connected nodes keep receiving None.
        return None
class Eclipse(Node):
    """Placeholder eclipse model: inputs ``r`` and ``S``, output ``O``; solve() does nothing yet."""

    r = Input()
    S = Input()
    O = Output()

    def solve(self, t):
        # Not implemented: O is never computed from r and S.
        return None
if __name__ == '__main__':
    # Start on the x axis with the circular-orbit speed along y.
    radius = 7018136.30000
    speed = np.sqrt(GM_earth.value / radius)
    orbit = Orbit(r=np.array([radius, 0, 0]), v=np.array([0, speed, 0]))

    mag_field = MagneticField()
    sun = SunModel()
    eclipse = Eclipse()

    solver = Solver(orbit, mag_field, sun, eclipse)
    links = (
        (orbit, 'r', mag_field),
        (orbit, 'r', eclipse),
        (sun, 'S', eclipse),
    )
    for source, prop, target in links:
        solver.connect(source, prop, target)

    result = solver.solve(0, 3600)
    print(result)

    # Trajectory in the x-y plane; assumes orbit's r occupies the first state entries.
    plt.figure()
    plt.plot(result.y[0, :], result.y[1, :])
    plt.axis('equal')
    plt.show()

    graph = solver.node_graph.get_human_readable_graph()
    graph.write_png("m1-pydot.png")
    print(graph)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment