// GitHub gist cblp/7ce1cdfaed13d180ae85a5911ee259a4, created September 27, 2022 22:40.
// Save it to your computer and use it in GitHub Desktop.
// Note from the gist page: this file contains hidden or bidirectional Unicode
// text that may be interpreted or compiled differently than what appears below.
// To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters.
#include <cmath>
#include <complex>
#include <functional>
#include <iostream>
#include <memory>
#include <numbers>
#include <string>
#include <utility>
using namespace std;
using namespace std::numbers;
// Abstract node of a symbolic expression tree in the single variable x.
// Nodes are shared via shared_ptr and never mutated after construction.
struct Expr {
    // Polymorphic base: a virtual destructor makes destruction through an
    // Expr pointer/owner run the derived destructor.
    virtual ~Expr() = default;

    // Renders this expression as human-readable text.
    virtual string source() const = 0;

    // Returns a new expression tree for d/dx of this expression.
    virtual shared_ptr<Expr> derive() const = 0;
};
struct Lit: Expr { | |
int n; | |
Lit(int n) : n(n) {} | |
string source() const { return to_string(n); } | |
shared_ptr<Expr> derive() const { return make_shared<Lit>(0); } | |
}; | |
struct Var: Expr { | |
string source() const { return "x"; } | |
shared_ptr<Expr> derive() const { return make_shared<Lit>(1); } | |
}; | |
struct Mul: Expr { | |
shared_ptr<Expr> a, b; | |
Mul(shared_ptr<Expr> a, shared_ptr<Expr> b) : a(a), b(b) {} | |
string source() const { return a->source() + " * " + b->source(); } | |
shared_ptr<Expr> derive() const { return nullptr; } | |
}; | |
struct Pow: Expr { | |
shared_ptr<Expr> a; | |
int n; | |
Pow(shared_ptr<Expr> a, int n) : a(a), n(n) {} | |
string source() const { return a->source() + " ^ " + to_string(n); } | |
shared_ptr<Expr> derive() const { | |
return make_shared<Mul>( | |
make_shared<Lit>(n), | |
make_shared<Mul>(make_shared<Pow>(a, n - 1), a->derive()) | |
); | |
} | |
}; | |
struct Cos; | |
struct Sin: Expr { | |
shared_ptr<Expr> a; | |
Sin(shared_ptr<Expr> a) : a(a) {} | |
string source() const { return "sin(" + a->source() + ")"; } | |
shared_ptr<Expr> derive() const { | |
return make_shared<Mul>(make_shared<Cos>(a), a->derive()); | |
} | |
}; | |
struct Cos: Expr { | |
shared_ptr<Expr> a; | |
Cos(shared_ptr<Expr> a) : a(a) {} | |
string source() const { return "cos(" + a->source() + ")"; } | |
shared_ptr<Expr> derive() const { return nullptr; } | |
}; | |
shared_ptr<Expr> sin(shared_ptr<Expr> a) { | |
return make_shared<Sin>(a); | |
} | |
shared_ptr<Expr> pow(shared_ptr<Expr> a, int n) { | |
return make_shared<Pow>(a, n); | |
} | |
string source(function<shared_ptr<Expr>(shared_ptr<Expr>)> f) { | |
return f(make_shared<Var>())->source(); | |
} | |
shared_ptr<Expr> derive(function<shared_ptr<Expr>(shared_ptr<Expr>)> f) { | |
return f(make_shared<Var>())->derive(); | |
} | |
// f(x) = sin(x)^4, written once and used with two argument types:
//  - T = double: the unqualified sin/pow calls resolve to the <cmath>
//    functions and compute a number;
//  - T = shared_ptr<Expr>: they resolve to the Expr-building sin/pow
//    overloads declared earlier in this file and produce an expression tree.
template <typename T>
T f(T x) {
    auto const s = sin(x);
    return pow(s, 4);
}
int main() { | |
cout << f(pi / 2) << endl; | |
// 1 | |
cout << source(f<shared_ptr<Expr>>) << endl; | |
// sin(x) ^ 4 | |
cout << derive(f<shared_ptr<Expr>>)->source() << endl; | |
// 4 * sin(x) ^ 3 * cos(x) * 1 | |
return 0; | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.