-
-
Save darkfrog26/60a42e8e803ab84dce915245dafb1c03 to your computer and use it in GitHub Desktop.
Linear regression with pure Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.concurrent._ | |
import scala.concurrent.ExecutionContext.Implicits.global | |
import scala.concurrent.duration._ | |
// Based on the original work: https://gist.github.com/otobrglez/08e9064209c9fc777ea5/d8000c300d2fd72db5a3445b8a93612680acaabb
// For kicks, I cleaned this up.
/**
 * Least-squares linear regression over [[Data]] points, plus a small demo entry point.
 */
object Regression {

  /**
   * Fits the line `y = slope * x + intercept` to `pairs` by ordinary least squares.
   *
   * The five running sums (x, y, x^2, y^2, x*y) are started as separate futures
   * before they are combined, so they can run concurrently on the implicit
   * execution context instead of one after another.
   *
   * @param pairs the observations to fit
   * @return a future holding the fitted [[Regression]]. Its third field is the
   *         squared correlation of x and y, or `0.0` when the denominator of that
   *         ratio is zero. The future fails through the assertion below when the
   *         slope denominator evaluates to exactly zero (for example, empty input
   *         or a single point).
   */
  def linear(pairs: Seq[Data]): Future[Regression] = {
    // One independent pass over the data per accumulated term.
    def sumOf(term: Data => Double): Future[Double] =
      Future(pairs.foldLeft(0.0)((sum, d) => sum + term(d)))

    // A Future starts executing as soon as it is created. Creating all five here,
    // before the for-comprehension, lets them overlap; creating them inside the
    // comprehension would chain them sequentially through flatMap.
    val sumX = sumOf(_.x)
    val sumY = sumOf(_.y)
    val sumXX = sumOf(d => math.pow(d.x, 2.0))
    val sumYY = sumOf(d => math.pow(d.y, 2.0))
    val sumXY = sumOf(d => d.x * d.y)

    for {
      x1 <- sumX
      y1 <- sumY
      x2 <- sumXX
      y2 <- sumYY
      xy <- sumXY
    } yield {
      val n = pairs.size
      // Shared denominator of slope and intercept; also the x-variance term of r^2.
      val dn = n * x2 - math.pow(x1, 2.0)
      assert(dn != 0.0, "Can't solve the system!")

      // Covariance-like numerator, reused by the slope and by r^2.
      val num = (n * xy) - (x1 * y1)
      val slope = num / dn
      val intercept = ((y1 * x2) - (x1 * xy)) / dn

      // y-variance term of r^2.
      val t3 = (n * y2) - math.pow(y1, 2)
      val r2Denominator = dn * t3

      // Guard against 0/0 when all y values are identical.
      val r2 = if (r2Denominator != 0.0) (num * num) / r2Denominator else 0.0
      Regression(slope, intercept, r2)
    }
  }

  /**
   * Demo entry point: fits a fixed six-point data set, then prints the model
   * and its prediction at `10.0`.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val data = Seq(
      Data(1.0, 4.0),
      Data(2.0, 6.0),
      Data(4.0, 12.0),
      Data(5.0, 15.0),
      Data(10.0, 34.0),
      Data(20.0, 68.0)
    )
    // Blocking is acceptable here because this is the edge of the program.
    val regression = Await.result(linear(data), Duration.Inf)
    println(regression)
    println(s"Solved? 10.0 -> ${regression(10.0)}")
  }
}
/**
 * A single (x, y) observation.
 *
 * @param x the x coordinate of the observation
 * @param y the y coordinate of the observation
 */
case class Data(
  x: Double,
  y: Double
)
/**
 * A straight line described by a slope and an intercept, with one extra stored value.
 *
 * @param slope       multiplier applied to the input value
 * @param intercept   constant added to the scaled input
 * @param probability additional value supplied at construction; this class only
 *                    stores it and includes it in `toString`
 */
case class Regression(slope: Double, intercept: Double, probability: Double) {

  /**
   * Evaluates the line at `value`.
   *
   * @param value the input to evaluate the line at
   * @return `slope * value + intercept`
   */
  def apply(value: Double): Double = {
    val scaled = slope * value
    scaled + intercept
  }

  /** Renders the three fields as labelled, comma-separated text. */
  override def toString: String =
    Seq(
      s"slope: $slope",
      s"intercept: $intercept",
      s"probability: $probability"
    ).mkString(", ")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment