/*
 * Created July 12, 2015 04:05
 * Gist: shigemk2/817e6c2267ec08929cf6 (can be saved locally or opened with GitHub Desktop).
 * GitHub notice: this file may contain hidden or bidirectional Unicode text that could be
 * interpreted or compiled differently from how it appears. To review it, open the file in an
 * editor that reveals hidden Unicode characters (see GitHub's documentation on
 * bidirectional Unicode characters).
 */
/** An arithmetic expression tree over integers and monomials. */
sealed trait Expr

/** An integer constant `n`. */
final case class N(n: Int) extends Expr

/** A monomial `a * x^n`: variable name `x`, coefficient `a`, exponent `n`. */
final case class Var(x: String, a: Int, n: Int) extends Expr

/** A sum of zero or more sub-expressions (the field `n` holds the terms). */
final case class Add(n: Expr*) extends Expr

/** A product of zero or more sub-expressions (the field `n` holds the factors). */
final case class Mul(n: Expr*) extends Expr
/** Builds the monomial `a * x^n` in the variable named "x". */
def x(a: Int, n: Int): Var = Var("x", a, n)
/** Evaluates an expression to an `Int`.
  *
  * @param e   the expression to evaluate
  * @param env values for variables, keyed by variable name (for example `Map("x" -> 2)`).
  *            It defaults to empty, so expressions without variables evaluate exactly as before.
  * @return the integer value of `e`; a monomial `Var(name, a, n)` evaluates to `a * v^n`
  *         where `v = env(name)`
  * @throws IllegalArgumentException if `e` contains a variable with no binding in `env`,
  *         or a variable with a negative exponent (not representable as an `Int`)
  */
def eval(e: Expr, env: Map[String, Int] = Map.empty): Int = e match {
  case N(value) => value
  case Var(name, a, n) =>
    // Previously this case was missing and hidden by @unchecked, so any Var threw MatchError.
    val v = env.getOrElse(name, throw new IllegalArgumentException(s"unbound variable: $name"))
    require(n >= 0, s"negative exponent not supported: $name^$n")
    a * List.fill(n)(v).product
  case Add(xs @ _*) => xs.map(eval(_, env)).sum
  case Mul(xs @ _*) => xs.map(eval(_, env)).product
}
/** Renders an expression in infix notation.
  *
  * Sum terms are joined with "+". A term for which `isneg` holds is appended
  * directly, because it already renders with its own leading "-". Product
  * factors are joined with "*". A sum nested in a sum, and a sum or product
  * nested in a product, is wrapped in parentheses. A monomial omits a
  * coefficient of 1, renders a coefficient of -1 as "-", and omits an
  * exponent of 1. An empty sum or product renders as "".
  */
def str(e: Expr): String = {
  // One summand: a nested sum gets parentheses, anything else renders as-is.
  def term(t: Expr): String = t match {
    case inner: Add => "(" ++ str(inner) ++ ")"
    case other      => str(other)
  }
  // One factor: nested sums and nested products both get parentheses.
  def factor(t: Expr): String = t match {
    case inner @ (_: Add | _: Mul) => "(" ++ str(inner) ++ ")"
    case other                      => str(other)
  }
  e match {
    case N(value) => value.toString
    case Var(name, coeff, exp) =>
      val prefix = coeff match {
        case 1  => ""
        case -1 => "-"
        case c  => c.toString
      }
      val suffix = if (exp == 1) "" else "^" ++ exp.toString
      prefix ++ name ++ suffix
    case Add(terms @ _*) =>
      terms.zipWithIndex.map { case (t, i) =>
        val sep = if (i == 0 || isneg(t)) "" else "+"
        sep ++ term(t)
      }.mkString
    case Mul(factors @ _*) =>
      factors.map(factor).mkString("*")
  }
}
/** True for a negative constant or a monomial with a negative coefficient.
  * Every other expression, including sums and products, counts as non-negative.
  */
def isneg(e: Expr): Boolean = e match {
  case N(value)         => value < 0
  case Var(_, coeff, _) => coeff < 0
  case _                => false
}
/** Ordering predicate used by `xsort`.
  *
  * If both arguments are monomials in "x", compares their exponents (`<`).
  * A monomial in "x" is never "less than" anything else. In every other case
  * (the first argument is not a monomial in "x") the result is `true`.
  */
def xlt(x: Expr, y: Expr): Boolean = {
  // Exponent of a monomial in the variable "x"; None for anything else.
  def power(e: Expr): Option[Int] = e match {
    case Var("x", _, n) => Some(n)
    case _              => None
  }
  (power(x), power(y)) match {
    case (Some(p), Some(q)) => p < q
    case (Some(_), None)    => false
    case _                  => true
  }
}
/** Recursively orders the terms of every sum by degree in "x".
  *
  * Within each `Add`, monomials in "x" come first in descending exponent
  * order. Monomials with equal exponents end up in the reverse of their
  * original order, because of how the partitioning below treats ties. All
  * other terms follow in their original order. Children of sums and products
  * are sorted recursively. Any other expression is returned unchanged.
  */
def xsort(xs: Expr): Expr = xs match {
  case Add(terms @ _*) =>
    // Quicksort driven by xlt: terms that are not xlt-less than the pivot go before it.
    def qsort(ts: List[Expr]): List[Expr] = ts match {
      case Nil => Nil
      case pivot :: rest =>
        val (after, before) = rest.partition(t => xlt(t, pivot))
        qsort(before) ++ (pivot :: qsort(after))
    }
    // Sort every child exactly once, before partitioning, so the first term
    // (the initial pivot) is sorted too. This is order-neutral: xlt only
    // inspects top-level Vars, which xsort returns unchanged.
    Add(qsort(terms.toList.map(xsort)): _*)
  case Mul(factors @ _*) => Mul(factors.map(xsort): _*)
  case other => other
}
/** Splices the terms of every `Add` in `xs` into one flat list.
  *
  * This is recursive, so sums nested inside those sums are spliced too.
  * Other elements stay in place. A sum inside a non-`Add` expression (for
  * example inside a `Mul`) is left untouched.
  */
def flatten(xs: List[Expr]): List[Expr] = xs.flatMap {
  case Add(inner @ _*) => flatten(inner.toList)
  case other           => List(other)
}
/** Collapses a list of terms into one expression.
  *
  * Returns `N(0)` for no terms, the term itself for a single term, and
  * otherwise an `Add` of all of them.
  */
def add(xs: List[Expr]): Expr = xs match {
  case Nil           => N(0)
  case single :: Nil => single
  case _             => Add(xs: _*)
}
/** Simplifies a polynomial expression.
  *
  * For a sum, the steps are:
  *  - flatten nested sums and order the terms with `xsort`;
  *  - drop zero constants and zero-coefficient monomials;
  *  - add adjacent constants together;
  *  - add the coefficients of adjacent "x" monomials that share an exponent;
  *  - simplify every remaining term recursively.
  *
  * The result is collapsed with `add`, so an empty sum becomes `N(0)`. For a
  * product, each factor is simplified. Anything else is returned unchanged.
  * Only adjacent like terms are merged, so this depends on the `xsort` ordering.
  */
def xsimplify(xs: Expr): Expr = xs match {
  case Add(terms @ _*) =>
    def combine(ts: List[Expr]): List[Expr] = ts match {
      case Nil                    => Nil
      case N(0) :: rest           => combine(rest)
      case Var(_, 0, _) :: rest   => combine(rest)
      case last :: Nil            => List(xsimplify(last))
      case N(c1) :: N(c2) :: rest => combine(N(c1 + c2) :: rest)
      case Var("x", c1, e1) :: Var("x", c2, e2) :: rest if e1 == e2 =>
        combine(Var("x", c1 + c2, e1) :: rest)
      case head :: rest => xsimplify(head) :: combine(rest)
    }
    // xsort always returns an Add when it is given an Add.
    val ordered = (xsort(Add(flatten(terms.toList): _*)): @unchecked) match {
      case Add(sorted @ _*) => sorted.toList
    }
    add(combine(ordered))
  case Mul(factors @ _*) => Mul(factors.map(xsimplify): _*)
  case other => other
}
/** Multiplies two expressions, distributing over sums.
  *
  * How each pair of operands is combined:
  *  - Constants and monomials combine directly: coefficients multiply, and
  *    exponents add when both monomials use the same variable.
  *  - Monomials in different variables become a `Mul` of the two.
  *  - If either side is a sum, the product is distributed over its terms and
  *    the result is an `Add`.
  *  - Otherwise, products are concatenated into a single `Mul`.
  */
def multiply(xs1: Expr, xs2: Expr): Expr = (xs1, xs2) match {
  case (N(n1), N(n2)) => N(n1 * n2)
  case (N(n1), Var(x, a2, n2)) => Var(x, n1 * a2, n2)
  case (Var(x, a1, n1), N(n2)) => Var(x, a1 * n2, n1)
  case (Var(x, a1, n1), Var(y, a2, n2)) if x == y => Var(x, a1 * a2, n1 + n2)
  // Unguarded on purpose: equal names were handled above. Without a guard here
  // the compiler can verify that this match is exhaustive.
  case (v1: Var, v2: Var) => Mul(v1, v2)
  case (Add(xs1 @ _*), Add(xs2 @ _*)) =>
    Add((for (x1 <- xs1; x2 <- xs2) yield multiply(x1, x2)): _*)
  case (Add(xs1 @ _*), x2) => Add(xs1.map(x1 => multiply(x1, x2)): _*)
  case (x1, Add(xs2 @ _*)) => Add(xs2.map(x2 => multiply(x1, x2)): _*)
  case (Mul(xs1 @ _*), Mul(xs2 @ _*)) => Mul(xs1.toList ++ xs2.toList: _*)
  case (Mul(xs1 @ _*), x2) => Mul(xs1.toList :+ x2: _*)
  case (x1, Mul(xs2 @ _*)) => Mul(x1 :: xs2.toList: _*)
}
/** Collapses a list of factors into one expression.
  *
  * Returns `N(1)` for no factors, the factor itself for a single factor, and
  * otherwise a `Mul` of all of them.
  */
def mul(xs: List[Expr]): Expr =
  if (xs.isEmpty) N(1)
  else if (xs.tail.isEmpty) xs.head
  else Mul(xs: _*)
/** Performs one expansion step.
  *
  * For a product:
  *  - With at least two factors, the first two are multiplied out with
  *    `multiply` and the remaining factors are kept.
  *  - A single factor is expanded recursively.
  *  - The resulting factor list is collapsed with `mul`. A product reduced to
  *    one factor is therefore returned as that factor, not as a one-element
  *    `Mul` that `flatten`/`xsimplify` could not splice into an enclosing sum.
  *  - An empty product becomes `N(1)`.
  *
  * For a sum, only the first term that changes under `expand` is replaced;
  * the other terms are kept as they are. Anything else is returned unchanged.
  * Repeat with `expandAll` to reach a fixed point.
  */
def expand(xs: Expr): Expr = xs match {
  case Mul(factors @ _*) =>
    val step = factors.toList match {
      case Nil            => Nil
      case single :: Nil  => List(expand(single))
      case a :: b :: rest => multiply(a, b) :: rest
    }
    mul(step)
  case Add(terms @ _*) =>
    // Expand each term once and reuse the result. Calling expand separately
    // for the test and for the result grows exponentially with nesting depth.
    def expandFirst(ts: List[Expr]): List[Expr] = ts match {
      case Nil => Nil
      case t :: rest =>
        val expanded = expand(t)
        if (expanded != t) expanded :: rest else t :: expandFirst(rest)
    }
    Add(expandFirst(terms.toList): _*)
  case other => other
}
/** Applies `expand` repeatedly until the expression stops changing.
  *
  * Each step computes `expand` once (previously up to three times). The
  * recursive call is in tail position.
  */
def expandAll(x: Expr): Expr = {
  val next = expand(x)
  if (next == x) x else expandAll(next)
}
// Prints the outcome of one check: `true` when the computed value matches the expectation.
def check(ok: Boolean): Unit = println(ok)

// eval
check(eval(Add(N(1), N(2))) == 1 + 2)
check(eval(Add(N(2), N(3))) == 2 + 3)
check(eval(Add(N(5), N(-3))) == 5 - 3)
check(eval(Mul(N(3), N(4))) == 3 * 4)
check(eval(Add(N(1), Mul(N(2), N(3)))) == 1 + 2 * 3)
check(eval(Mul(Add(N(1), N(2)), N(3))) == (1 + 2) * 3)

// str on constants, sums and products
check(str(Add(N(1), N(2), N(3))) == "1+2+3")
check(str(Add(N(1), N(-2), N(-3))) == "1-2-3")
check(str(Mul(N(1), N(2), N(3))) == "1*2*3")
check(str(Add(N(1), Mul(N(2), N(3)))) == "1+2*3")
check(str(Mul(N(1), N(2), N(3))) == "1*2*3")
check(str(Add(Add(N(1), N(2)), N(3))) == "(1+2)+3")
check(str(Mul(Add(N(1), N(2)), N(3))) == "(1+2)*3")
check(str(Mul(Mul(N(1), N(2)), N(3))) == "(1*2)*3")
check(Add(N(1), N(2)) == Add(N(1), N(2)))

// str on monomials
check(str(Add(x(1, 1), N(1))) == "x+1")
check(str(Add(x(1, 3), x(-1, 2), x(-2, 1), N(1))) == "x^3-x^2-2x+1")

// xsort
val f = Mul(Add(N(5), x(2, 1)), Add(x(1, 2), x(1, 1), N(1), x(3, 3)))
check(str(f) == "(5+2x)*(x^2+x+1+3x^3)")
check(str(xsort(f)) == "(2x+5)*(3x^3+x^2+x+1)")

// xsimplify
val g1 = Add(x(2, 1), N(3), x(4, 2), x(1, 1), N(1), x(1, 2))
check(str(g1) == "2x+3+4x^2+x+1+x^2")
check(str(xsimplify(g1)) == "5x^2+3x+4")
val g2 = Mul(Add(x(1, 1), N(0), x(2, 1)), Add(x(1, 2), Add(N(1), x(2, 2)), N(2)))
check(str(g2) == "(x+0+2x)*(x^2+(1+2x^2)+2)")
check(str(xsimplify(g2)) == "3x*(3x^2+3)")
val g3 = Add(x(1, 1), N(1), x(0, 2), x(1, 1), N(1), x(-2, 1), N(-2))
check(str(g3) == "x+1+0x^2+x+1-2x-2")
check(str(xsimplify(g3)) == "0")

// multiply
check(str(N(2)) == "2")
check(str(N(3)) == "3")
check(str(multiply(N(2), N(3))) == "6")
check(str(N(2)) == "2")
check(str(x(3, 2)) == "3x^2")
check(str(multiply(N(2), x(3, 2))) == "6x^2")
check(str(x(2, 3)) == "2x^3")
check(str(x(3, 4)) == "3x^4")
check(str(multiply(x(2, 3), x(3, 4))) == "6x^7")
check(str(N(2)) == "2")
check(str(Add(x(1, 1), x(2, 2), N(3))) == "x+2x^2+3")
check(str(multiply(N(2), Add(x(1, 1), x(2, 2), N(3)))) == "2x+4x^2+6")
check(str(Add(x(1, 1), N(1))) == "x+1")
check(str(Add(x(2, 1), N(3))) == "2x+3")
check(str(multiply(Add(x(1, 1), N(1)), Add(x(2, 1), N(3)))) == "2x^2+3x+2x+3")
check(str(xsimplify(multiply(Add(x(1, 1), N(1)), Add(x(2, 1), N(3))))) == "2x^2+5x+3")

// expand / expandAll on (x+1)*(x+2)*(x+3)
val cubic = Mul(Add(x(1, 1), N(1)), Add(x(1, 1), N(2)), Add(x(1, 1), N(3)))
check(str(cubic) == "(x+1)*(x+2)*(x+3)")
check(str(expand(cubic)) == "(x^2+2x+x+2)*(x+3)")
check(str(Add(N(1), cubic)) == "1+(x+1)*(x+2)*(x+3)")
check(str(expand(Add(N(1), cubic))) == "1+(x^2+2x+x+2)*(x+3)")
check(str(Add(N(1), cubic)) == "1+(x+1)*(x+2)*(x+3)")
check(str(expandAll(Add(N(1), cubic))) == "1+(x^3+3x^2+2x^2+6x+x^2+3x+2x+6)")
check(str(xsimplify(expandAll(Add(N(1), cubic)))) == "x^3+6x^2+11x+7")
check(str(xsimplify(Add(N(1), Add(x(1, 3), x(3, 2), x(2, 2), x(6, 1), x(1, 2), x(3, 1), x(2, 1), N(6))))) == "x^3+6x^2+11x+7")
// GitHub page footer: sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.