-
-
Save 7shi/7078516eaa5980334271 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.language.implicitConversions | |
/** Symbolic expression tree: rational constants ([[N]]), monomials `a*v^n` ([[Var]]),
  * sums ([[Add]]) and products ([[Mul]]).
  *
  * The arithmetic operators only build trees; [[sort]], [[simplify]], [[expand]] and
  * [[expandAll]] rewrite them, and [[eval]] folds a variable-free tree to a constant.
  */
sealed trait Expr {
  /** Builds the unsimplified sum `this + that`. */
  def +(that: Expr): Expr = Add(List(this, that))

  /** Builds the unsimplified product `this * that`. */
  def *(that: Expr): Expr = Mul(List(this, that))

  /** Adds the integer constant `that`. */
  def +(that: Int): Expr = this + N(that)

  /** Subtracts `that`, stored as adding the constant `-that` (so `N(1) - 2` renders as `1-2`). */
  def -(that: Int): Expr = this + N(-that)

  /** Multiplies by the integer constant `that`. */
  def *(that: Int): Expr = this * N(that)

  /** Renders the tree with `+` / `*` separators. A sum or product that is the sole remaining
    * element of an enclosing sum/product is parenthesized.
    */
  override def toString: String = this match {
    case n: N => n.toString
    case Var(v, N(1, 1), N(1, 1)) => v
    case Var(v, N(-1, 1), N(1, 1)) => "-" + v
    case Var(v, a, N(1, 1)) => a.rstr(" ") + v
    case Var(v, a, n) => s"${Var(v, a)}^$n"
    case Add(Nil) => ""
    case Add(List(Add(ys))) => s"(${Add(ys)})"
    case Add(List(y)) => y.toString
    // A negative summand already prints its own "-", so no "+" is inserted before it.
    case Add(y :: ys) if ys.head.isNeg => s"${Add(List(y))}${Add(ys)}"
    case Add(y :: ys) => s"${Add(List(y))}+${Add(ys)}"
    case Mul(Nil) => ""
    case Mul(List(Add(ys))) => s"(${Add(ys)})"
    case Mul(List(Mul(ys))) => s"(${Mul(ys)})"
    case Mul(List(y)) => y.toString
    case Mul(y :: ys) => s"${Mul(List(y))}*${Mul(ys)}"
  }

  /** Folds a variable-free tree to a single constant.
    *
    * An empty sum evaluates to 0 and an empty product to 1, matching `add` / `mul`.
    *
    * @throws UnsupportedOperationException if the tree contains a [[Var]]
    */
  def eval: N = this match {
    case n: N => n
    case Add(ys) => ys.map(_.eval).reduceLeftOption(_ + _).getOrElse(N(0))
    case Mul(ys) => ys.map(_.eval).reduceLeftOption(_ * _).getOrElse(N(1))
    case v: Var =>
      throw new UnsupportedOperationException(s"cannot evaluate an expression containing variable ${v.x}")
  }

  /** True for a negative constant or a monomial with a negative coefficient. */
  def isNeg: Boolean = this match {
    case n: N => n < N(0)
    case Var(_, a, _) => a < N(0)
    case _ => false
  }

  /** Ordering heuristic used by [[sort]].
    *
    * Monomials in the variable "x" are compared by exponent and rank above every other term.
    * This is NOT a strict weak ordering: for two non-"x" terms both `a < b` and `b < a` hold,
    * which is why [[sort]] uses its own partitioning instead of `sortWith`.
    */
  def <(that: Expr): Boolean = (this, that) match {
    case (Var("x", _, n1), Var("x", _, n2)) => n1 < n2
    case (Var("x", _, _), _) => false
    case _ => true
  }

  /** Negation of [[<]]. */
  def >=(that: Expr): Boolean = !(this < that)

  /** Reorders every sum so "x"-monomials come first, highest exponent first, followed by the
    * remaining terms.
    *
    * Every summand/factor (including the first one, which the quicksort uses as its pivot) is
    * sorted recursively exactly once. Sorting never changes a node's kind, so the
    * comparisons are unaffected by sorting the elements first.
    */
  def sort: Expr = this match {
    case Add(ys) =>
      def quicksort(list: List[Expr]): List[Expr] = list match {
        case Nil => Nil
        case pivot :: rest =>
          val (notLower, lower) = rest.partition(_ >= pivot)
          quicksort(notLower) ++ (pivot :: quicksort(lower))
      }
      Add(quicksort(ys.map(_.sort)))
    case Mul(ys) => Mul(ys.map(_.sort))
    case _ => this
  }

  /** Splices nested sums into one flat list of summands, recursively. */
  private def flatten(list: List[Expr]): List[Expr] = list.flatMap {
    case Add(inner) => flatten(inner)
    case other => List(other)
  }

  /** Smart constructor for sums: empty -> 0, single term -> that term. */
  private def add(list: List[Expr]): Expr = list match {
    case Nil => N(0)
    case List(only) => only
    case _ => Add(list)
  }

  /** Flattens, sorts and combines a sum.
    *
    * - Zero constants and zero-coefficient monomials are dropped.
    * - Adjacent constants are added.
    * - Adjacent monomials in the same variable with equal exponents are merged.
    *
    * Products only have their factors simplified.
    */
  def simplify: Expr = this match {
    case Add(ys) =>
      def combine(list: List[Expr]): List[Expr] = list match {
        case Nil => Nil
        case N(0, _) :: rest => combine(rest)
        case Var(_, N(0, _), _) :: rest => combine(rest)
        case List(last) => List(last.simplify)
        case (a1: N) :: (a2: N) :: rest => combine((a1 + a2) :: rest)
        // Previously restricted to the variable "x"; like terms of any variable are merged.
        case Var(v1, a1, n1) :: Var(v2, a2, n2) :: rest if v1 == v2 && n1 == n2 =>
          combine(Var(v1, a1 + a2, n1) :: rest)
        case head :: rest => head.simplify :: combine(rest)
      }
      val terms = Add(flatten(ys)).sort match {
        case Add(zs) => zs
        case other => List(other)
      }
      add(combine(terms))
    case Mul(ys) => Mul(ys.map(_.simplify))
    case _ => this
  }

  /** Performs one multiplication step.
    *
    * - Constants and same-variable monomials are multiplied out.
    * - Sums are distributed.
    * - Products are concatenated.
    * - Monomials in different variables are kept as a [[Mul]].
    */
  def multiply(that: Expr): Expr = (this, that) match {
    case (n1: N, n2: N) => n1 * n2
    case (n1: N, Var(v, a2, n2)) => Var(v, n1 * a2, n2)
    case (Var(v, a1, n1), n2: N) => Var(v, a1 * n2, n1)
    case (Var(v1, a1, n1), Var(v2, a2, n2)) if v1 == v2 => Var(v1, a1 * a2, n1 + n2)
    case (v1: Var, v2: Var) => v1 * v2
    case (Add(xs1), Add(xs2)) => Add(for (x1 <- xs1; x2 <- xs2) yield x1.multiply(x2))
    case (Add(xs1), x2) => Add(xs1.map(x1 => x1.multiply(x2)))
    case (x1, Add(xs2)) => Add(xs2.map(x2 => x1.multiply(x2)))
    case (Mul(xs1), Mul(xs2)) => Mul(xs1 ++ xs2)
    case (Mul(xs1), x2) => Mul(xs1 :+ x2)
    case (x1, Mul(xs2)) => Mul(x1 :: xs2)
  }

  /** Smart constructor for products: empty -> 1, single factor -> that factor. */
  private def mul(list: List[Expr]): Expr = list match {
    case Nil => N(1)
    case List(only) => only
    case _ => Mul(list)
  }

  /** One expansion step.
    *
    * - In a product, the first two factors are multiplied.
    * - In a sum, the first summand that changes when expanded is replaced.
    *
    * Use [[expandAll]] to iterate to a fixpoint.
    */
  def expand: Expr = this match {
    case Mul(ys) =>
      val stepped = ys match {
        case Nil => Nil
        case List(only) => List(only.expand)
        case first :: second :: rest => first.multiply(second) :: rest
      }
      mul(stepped)
    case Add(ys) =>
      def expandFirst(list: List[Expr]): List[Expr] = list match {
        case Nil => Nil
        case head :: rest =>
          val expanded = head.expand
          if (head != expanded) expanded :: rest else head :: expandFirst(rest)
      }
      add(expandFirst(ys))
    case _ => this
  }

  /** Repeats [[expand]] until the tree stops changing. */
  def expandAll: Expr = {
    @annotation.tailrec
    def loop(e: Expr): Expr = {
      val next = e.expand
      if (e != next) loop(next) else e
    }
    loop(this)
  }

  /** Derivative with respect to variable `x`.
    *
    * - Constants and monomials in other variables give 0.
    * - Sums are differentiated term-wise without simplification.
    * - Products use the product rule: d(f1*...*fk) = sum over i of f1*...*fi'*...*fk.
    */
  def differentiate(x: String): Expr = this match {
    case Add(ys) => Add(ys.map(_.differentiate(x)))
    case Var(y, a, N(1, 1)) if x == y => a
    case Var(y, a, n) if x == y => Var(x, a * n, n - N(1))
    case _: Var => N(0)
    case _: N => N(0)
    case Mul(ys) =>
      add(ys.indices.toList.map(i => mul(ys.updated(i, ys(i).differentiate(x)))))
  }

  /** Antiderivative with respect to variable `x`.
    *
    * Every [[Add]] node appends its own integration constant `C`, so nested sums receive
    * one `C` each.
    *
    * @throws ArithmeticException for `x^-1`, whose antiderivative is a logarithm
    *         (previously this produced a meaningless `1/0` coefficient)
    * @throws UnsupportedOperationException for products, which are not supported
    */
  def integrate(x: String): Expr = this match {
    case Add(ys) => Add(ys.map(_.integrate(x)) :+ Var("C"))
    case Var(y, a, n) if x == y =>
      val power: N = n + N(1)
      if (power.numer == 0)
        throw new ArithmeticException(s"cannot integrate $this: the antiderivative of $x^-1 is a logarithm")
      Var(x, a / power, power)
    case v: Var => v * Var(x)
    case n: N => Var(x, n)
    case m: Mul => throw new UnsupportedOperationException(s"integration of products is not supported: $m")
  }
}
// N is derived from Rational with some fixes.
// https://sites.google.com/site/scalajp/home/documentation/scala-by-example/chapter6
/** Rational constant `numer/denom`, also usable as a leaf of an [[Expr]] tree.
  *
  * Arithmetic results go through [[reduce]], which makes the denominator positive and
  * coprime with the numerator. Equality and `hashCode` compare the reduced value, so
  * `N(2, 4) == N(1, 2)`. An `N` with an integral value also equals that `Int`, which the
  * test script relies on (e.g. `(N(1) + 1).eval == 2`).
  *
  * @param numer numerator
  * @param denom denominator (defaults to 1)
  */
case class N(numer: Int, denom: Int = 1) extends Expr with Ordered[N] {
  /** Euclid's gcd of `x` and `y`, negative exactly when `y < 0` (0 when both are 0), so that
    * dividing the denominator by it leaves a positive denominator.
    */
  private def gcd(x: Int, y: Int): Int =
    if (x == 0) y
    else if (x < 0) gcd(-x, y)
    else if (y < 0) -gcd(x, -y)
    else gcd(y % x, x)

  /** Lowest-terms form with a positive denominator (throws for `N(0, 0)`). */
  def reduce: N = {
    val g = gcd(numer, denom)
    N(numer / g, denom / g)
  }

  /** Reduced `(numer, denom)` used for equality/hashing; left as-is when both are 0, so that
    * `equals` / `hashCode` never divide by zero.
    */
  private def normalized: (Int, Int) = {
    val g = gcd(numer, denom)
    if (g == 0) (numer, denom) else (numer / g, denom / g)
  }

  def +(that: N): N = N(numer * that.denom + that.numer * denom, denom * that.denom).reduce

  def -(that: N): N = N(numer * that.denom - that.numer * denom, denom * that.denom).reduce

  def *(that: N): N = N(numer * that.numer, denom * that.denom).reduce

  /** Quotient `this / that`.
    *
    * @throws ArithmeticException when `that` is zero (previously a meaningless `±1/0` was returned)
    */
  def /(that: N): N = {
    if (that.numer == 0) throw new ArithmeticException(s"division by zero: $this / $that")
    N(numer * that.denom, denom * that.numer).reduce
  }

  /** Sign of `this - that`; `-` reduces, so the denominator is positive and the numerator
    * carries the sign.
    */
  def compare(that: N): Int = (this - that).numer

  /** Value equality against another `N` (any representation) or an `Int`. */
  override def equals(other: Any): Boolean = other match {
    case that: N => normalized == that.normalized
    case that: Int => normalized == ((that, 1))
    case _ => false
  }

  /** Consistent with [[equals]]: integral values hash like the corresponding `Int`. */
  override def hashCode(): Int = normalized match {
    case (n, 1) => n.##
    case pair => pair.##
  }

  override def toString: String = if (denom == 1) numer.toString else s"$numer/$denom"

  /** Like `toString`, but a fraction is followed by `s` (used to separate a fractional
    * coefficient from the variable name, e.g. "1/3 x").
    */
  def rstr(s: String): String = if (denom == 1) numer.toString else toString + s
}
/** Implicit view `Int => N`. Despite its name, it converts an `Int` *into* an `N`
  * (denominator 1), letting integer literals stand where an `N` is expected, e.g. `x(2, 3)`.
  */
implicit def NToInt(n: Int): N = new N(n, 1)
/** Monomial `a * x^n`: variable name `x`, rational coefficient `a` and rational exponent `n`. */
case class Var(x: String, a: N = N(1), n: N = N(1)) extends Expr
/** Shorthand for a monomial in the variable "x": `x(a, n)` is `a*x^n`, and `x()` is plain `x`. */
def x(a: N = N(1), n: N = N(1)): Var = Var("x", a = a, n = n)
/** Sum of the terms in `xs`.
  *
  * `+` is overridden to append to this sum, so chaining `a + b + c` keeps one flat list
  * of summands instead of nesting sums.
  */
case class Add(xs: List[Expr]) extends Expr {
  override def +(that: Expr): Expr = copy(xs = xs :+ that)
}
/** Product of the factors in `xs`.
  *
  * `*` is overridden to append to this product, so chaining `a * b * c` keeps one flat list
  * of factors instead of nesting products.
  */
case class Mul(xs: List[Expr]) extends Expr {
  override def *(that: Expr): Expr = copy(xs = xs :+ that)
}
/** Tiny assertion helper: prints `[OK] tag` when `v == e`; otherwise prints `[NG] tag`
  * followed by the actual and expected values.
  */
def test(tag: String, v: Any, e: Any): Unit =
  if (v != e) {
    println(s"[NG] $tag")
    println(s" value : $v")
    println(s" expected: $e")
  } else println(s"[OK] $tag")
// ---- eval: constant trees fold to a single rational ----
test("eval 1", (N(1) + 1).eval, 1 + 1)
test("eval 2", (N(2) + 3).eval, 2 + 3)
test("eval 3", (N(5) - 3).eval, 5 - 3)
test("eval 4", (N(3) * 4).eval, 3 * 4)
test("eval 5", (N(1) + N(2) * 3).eval, 1 + 2 * 3)
test("eval 6", ((N(1) + 2) * 3).eval, (1 + 2) * 3)

// ---- toString: separators and parenthesization ----
test("str 1", (N(1) + 2 + 3).toString, "1+2+3")
test("str 2", (N(1) - 2 - 3).toString, "1-2-3")
test("str 3", (N(1) * 2 * 3).toString, "1*2*3")
test("str 4", (N(1) + N(2) * 3).toString, "1+2*3")
test("str 5", Add(List(N(1) + 2, N(3))).toString, "(1+2)+3")
test("str 6", ((N(1) + 2) * 3).toString, "(1+2)*3")
test("str 7", Mul(List(N(1) * 2, N(3))).toString, "(1*2)*3")
test("equal", N(1) + 2, N(1) + 2)

// ---- monomials in x and their ordering ----
test("x 1", (x() + 1).toString, "x+1")
test("x 2", (x(1, 3) + x(-1, 2) + x(-2) + 1).toString, "x^3-x^2-2x+1")
test("xlt 1", x() < x(1, 2), true)
test("xlt 2", N(1) < x(), true)

// ---- sort / simplify ----
test("xsort 1", {
  val f = x() + 1 + x(1, 2)
  f.toString -> f.sort.toString
}, ("x+1+x^2", "x^2+x+1"))
test("xsort 2", {
  val f = (N(5) + x(2)) * (x() + 1 + x(1, 2))
  f.toString -> f.sort.toString
}, ("(5+2x)*(x+1+x^2)", "(2x+5)*(x^2+x+1)"))
test("xsimplify 1", {
  val f = x(2) + 3 + x(4, 2) + x() + 1 + x(1, 2)
  f.toString -> f.simplify.toString
}, ("2x+3+4x^2+x+1+x^2", "5x^2+3x+4"))
test("xsimplify 2", {
  val f = (x() + 0 + x(2)) * Add(List(x(1, 2), N(1) + x(2, 2), N(2)))
  f.toString -> f.simplify.toString
}, ("(x+0+2x)*(x^2+(1+2x^2)+2)", "3x*(3x^2+3)"))
test("xsimplify 3", {
  val f = x() + 1 + x(0, 2) + x() + 1 + x(-2) - 2
  f.toString -> f.simplify.toString
}, ("x+1+0x^2+x+1-2x-2", "0"))

// ---- multiply ----
test("multiply 1", {
  val (f1, f2) = (N(2), N(3))
  (f1.toString, f2.toString, f1.multiply(f2).toString)
}, ("2", "3", "6"))
test("multiply 2", {
  val (f1, f2) = (N(2), x(3, 2))
  (f1.toString, f2.toString, f1.multiply(f2).toString)
}, ("2", "3x^2", "6x^2"))
test("multiply 3", {
  val (f1, f2) = (x(2, 3), x(3, 4))
  (f1.toString, f2.toString, f1.multiply(f2).toString)
}, ("2x^3", "3x^4", "6x^7"))
test("multiply 4", {
  val (f1, f2) = (N(2), x() + x(2, 2) + 3)
  (f1.toString, f2.toString, f1.multiply(f2).toString)
}, ("2", "x+2x^2+3", "2x+4x^2+6"))
test("multiply 5", {
  val f1 = x() + 1
  val f2 = x(2) + 3
  val product = f1.multiply(f2)
  (f1.toString, f2.toString, product.toString, product.simplify.toString)
}, ("x+1", "2x+3", "2x^2+3x+2x+3", "2x^2+5x+3"))

// ---- expand ----
test("expand 1", {
  val f = (x() + 1) * (x() + 2) * (x() + 3)
  f.toString -> f.expand.toString
}, ("(x+1)*(x+2)*(x+3)", "(x^2+2x+x+2)*(x+3)"))
test("expand 2", {
  val f = N(1) + (x() + 1) * (x() + 2) * (x() + 3)
  f.toString -> f.expand.toString
}, ("1+(x+1)*(x+2)*(x+3)", "1+(x^2+2x+x+2)*(x+3)"))
test("expandAll", {
  val original = N(1) + (x() + 1) * (x() + 2) * (x() + 3)
  val expanded = original.expandAll
  (original.toString, expanded.toString, expanded.simplify.toString)
}, ("1+(x+1)*(x+2)*(x+3)",
    "1+(x^3+3x^2+2x^2+6x+x^2+3x+2x+6)",
    "x^3+6x^2+11x+7"))

// ---- calculus ----
test("differentiate", {
  val f = x(1, 3) + x(1, 2) + x() + 1
  f.toString -> f.differentiate("x").toString
}, ("x^3+x^2+x+1", "3x^2+2x+1+0"))
test("integrate", {
  val f = x(1, 2) + x(2) + 1
  f.toString -> f.integrate("x").toString
}, ("x^2+2x+1", "1/3 x^3+x^2+x+C"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.