Created
February 13, 2024 22:57
-
-
Save nlitsme/96a9e1515c90a5c72fd21c0ec7b1e15e to your computer and use it in GitHub Desktop.
symbolic solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import sympy | |
import re | |
import keyword | |
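"""
symbolic expression tool
- expression with '==' -> solve for every variable not listed with '-v', or only for the variable given with '-x'
- expression without '==' -> simplify (expanded first, unless the expression uses 'factor')
- assignment 'name = expr': store the resulting expression in 'name' for use by later expressions
- variables can be a letter, or letter_number
- specify '-s' to solve all equations as a system of equations.
- specify '-n' to solve numerically instead of symbolically.
"""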
class Symbols:
    """Namespace used to ``eval`` one user expression with sympy objects.

    ``syms`` maps every name the expression may reference to a sympy object:
    the constants ``e`` and ``pi``, everything already stored in the global
    ``state`` (results of earlier assignments), and whatever ``addexpr``
    discovers. ``solvable`` lists, in order of first appearance, the names
    that ``addexpr`` turned into new symbols and that are not free variables.
    """

    def __init__(self, state, freevariables):
        """Create the namespace.

        Args:
            state: mapping of previously assigned names to sympy expressions.
                It is copied into ``syms`` so expressions can reference them.
            freevariables: the ``-v`` option string. A new symbol whose name
                occurs in this string is excluded from ``solvable``.
                NOTE(review): this is a substring test, so e.g. ``-v x_10``
                also marks ``x_1`` as free.
        """
        self.solvable = list()
        self.syms = dict()
        # Treat a missing option like the argparse default (no free variables).
        self.freevariables = freevariables or ''
        # 'e' is the only single-letter name predefined as a constant; every
        # other single letter becomes a symbol.
        self.syms['e'] = sympy.exp(1)
        self.syms['pi'] = 4*sympy.atan(1)
        # Names from the global state override the constants above.
        for k, v in state.items():
            self.syms[k] = v

    def addexpr(self, txte):
        """Register every function and variable name used in ``txte``.

        A name immediately followed by ``(`` is looked up as a sympy
        function. Any other unknown name becomes a new ``sympy.Symbol`` and,
        unless it is a free variable, is appended to ``solvable``. Names
        preceded by ``.`` (method calls such as ``.subs``) and names already
        in ``syms`` are skipped.
        """
        for m in re.finditer(r'(?<!\.)\b([a-zA-Z]\w*)\b(\()?', txte):
            name, bracket = m.groups()
            if name in self.syms:
                continue
            if bracket:
                # Function call. The None default lets an unknown function be
                # reported here instead of raising AttributeError; eval() will
                # then fail with a NameError naming it.
                fn = getattr(sympy, name, None)
                if fn is None:
                    print(f"could not find {name}")
                else:
                    self.syms[name] = fn
            else:
                # Variable: letter, letter+digits or letter+'_'+digits etc.
                self.syms[name] = sympy.symbols(name)
                if name not in self.freevariables:
                    self.solvable.append(name)
def evalsystem(state, args, txtexprs):
    """Solve all expressions in ``txtexprs`` together as one system.

    Each ``lhs == rhs`` is rewritten as ``(lhs)-(rhs)``, i.e. an expression
    that must equal zero. Text without ``==`` is taken to equal zero as is.
    The system is solved for every solvable symbol and the result printed:
    numerically with ``sympy.nsolve`` (starting guess 1 for each unknown)
    when ``args.numeric`` is set, otherwise symbolically with ``sympy.solve``.
    """
    namespace = Symbols(state, args.freevariables)
    equations = []
    for raw in txtexprs:
        text = raw.strip()
        equality = re.match(r'(.*)==(.*)', text)
        if equality:
            text = f"({equality[1]})-({equality[2]})"
        # Append '_' to Python keywords (e.g. 'lambda') so eval() treats
        # them as ordinary identifiers.
        for word in keyword.kwlist:
            text = re.sub(fr'\b{word}\b', lambda hit: hit[0] + '_', text)
        namespace.addexpr(text)
        # NOTE: eval() executes arbitrary Python; only use with trusted input.
        equations.append(eval(text, namespace.syms))

    unknowns = namespace.solvable
    if not args.numeric:
        print(sympy.solve(equations, unknowns))
        return
    print(sympy.nsolve(equations, unknowns, (1,) * len(unknowns)))
def evalexpr(state, args, txte):
    """Evaluate one command-line expression, then print or store the result.

    Three forms are recognised:

    * ``lhs == rhs``: the equation ``(lhs)-(rhs) == 0`` is solved for
      ``args.variable`` when given (``-x``), otherwise for each solvable
      symbol in turn. ``args.numeric`` (``-n``) selects ``sympy.nsolve``
      with starting guess 1 instead of ``sympy.solve``.
    * ``name = expr``: the evaluated expression is stored in ``state[name]``
      so later expressions can refer to it.
    * anything else: the evaluated expression is printed.

    Unless the text mentions ``factor``, the expression is expanded and
    simplified before use.
    """
    txte = txte.strip()
    exprname = None
    haseq = False
    # Is this an equality? -> solve it.
    if m := re.match(r'(.*)==(.*)', txte):
        txte = f"({m[1]})-({m[2]})"
        haseq = True
    # Is this an assignment? -> remember the result in state.
    elif m := re.match(r'(\w+)\s*=\s*(.*)', txte):
        exprname = m[1]
        txte = f"({m[2]})"

    s = Symbols(state, args.freevariables)
    # Append '_' to Python keywords (e.g. 'lambda') so eval() treats them
    # as ordinary identifiers.
    for kw in keyword.kwlist:
        txte = re.sub(fr'\b{kw}\b', lambda hit: hit[0] + '_', txte)
    s.addexpr(txte)
    # NOTE: eval() executes arbitrary Python; only use with trusted input.
    expr = eval(txte, s.syms)
    if 'factor' not in txte:
        # Expand + simplify only when the user did not explicitly ask for a
        # factored form (expand() would undo the factoring).
        expr = sympy.expand(expr)
        expr = sympy.simplify(expr)

    if haseq:
        if args.variable:
            name = args.variable
            # Apply the same keyword escaping as the expression text above,
            # so e.g. '-x lambda' finds the symbol 'lambda_'.
            if keyword.iskeyword(name):
                name += '_'
            v = s.syms.get(name)
            if v is None:
                # Without this check solve()/nsolve() would be called with
                # None as the unknown.
                print(f"variable {args.variable} does not occur in the expression")
                return
            print(_solve_for(expr, v, args.numeric))
        else:
            # Solve for each solvable symbol in turn; a failure for one
            # symbol (e.g. nsolve not converging) is reported and the loop
            # continues with the next one.
            for k in s.solvable:
                v = s.syms.get(k)
                if v is None:
                    continue
                try:
                    r = _solve_for(expr, v, args.numeric)
                    print(k, "->", r)
                except Exception as err:
                    print(err)
    elif exprname:
        state[exprname] = expr
    else:
        # Print the (expanded + simplified) expression.
        print(expr)


def _solve_for(expr, var, numeric):
    """Solve ``expr == 0`` for ``var``.

    Uses ``sympy.nsolve`` with starting guess 1 when ``numeric`` is true,
    otherwise ``sympy.solve``.
    """
    if numeric:
        return sympy.nsolve(expr, var, (1,))
    return sympy.solve(expr, var)
def searchdocs(query):
    """Print the name and docstring of every sympy attribute matching ``query``.

    An attribute matches when its name equals ``query`` or its docstring
    contains ``query`` as a substring.
    """
    for attrname in dir(sympy):
        doc = getattr(sympy, attrname).__doc__
        doc_matches = bool(doc) and query in doc
        name_matches = attrname == query
        if not (doc_matches or name_matches):
            continue
        print(attrname)
        print(doc)
def main():
    """Command-line entry point: parse arguments and dispatch.

    ``-q`` searches the sympy documentation. ``-s`` solves all expressions
    as one system. Otherwise each expression is evaluated in order, sharing
    one ``state`` dict so assignments are visible to later expressions.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='symbolic calculations',
        # Keep the epilog's line breaks: the default formatter would re-wrap
        # the example lines below into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
examples:
  f = a*x**2 + b*x + c
  f.subs(x, x+1) + f.subs(x, x-1) == 0
  factor(x**2-1)
  (x**2-1).factor()
  -n "sqrt(x)==cos(x)"
  -s "3*a+4*b+7*c==1" "11*a+13*b+2*c==1"
""")
    parser.add_argument('--query', '-q', type=str, help='search sympy documentation')
    parser.add_argument('--variable', '-x', type=str, help='solve for the specified variable')
    parser.add_argument('--freevariables', '-v', type=str, help='free variables', default='')
    parser.add_argument('--system', '-s', action='store_true', help='solve as a system of equations')
    parser.add_argument('--numeric', '-n', action='store_true', help='approximate solution')
    parser.add_argument('expressions', nargs='*', type=str, help='expressions')
    args = parser.parse_args()
    if args.query:
        searchdocs(args.query)
        return
    if not args.expressions:
        # Nothing to evaluate: show usage instead of exiting silently.
        parser.print_help()
        return
    # Shared across expressions so 'name = expr' results can be reused.
    state = dict()
    if args.system:
        evalsystem(state, args, args.expressions)
    else:
        for e in args.expressions:
            evalexpr(state, args, e)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment