Last active
March 7, 2016 01:27
-
-
Save shashir/64db06285ffa3cf2c28b to your computer and use it in GitHub Desktop.
Decision tree learning using ID3 algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://github.com/shashir/cs7641/blob/master/lessons/ml_p1_lesson_01_decision_trees.md#id3-algorithm | |
import scala.collection.mutable.{Set => MSet} | |
/**
 * A general (n-ary) tree node holding a value and a mutable set of children.
 *
 * `toString` renders the tree as ASCII art: each child is prefixed with
 * `|___`; descendants of non-last children keep a `| ` gutter, while the
 * last child's descendants are indented with spaces, and the rendering of
 * the last child is terminated with a `.` line.
 */
case class Tree[T](
  value: T,
  children: MSet[Tree[T]] = MSet.empty[Tree[T]]
) {
  override def toString(): String = {
    // Re-indents an already-rendered subtree by joining its lines with `pad`.
    def reindent(rendered: String, pad: String): String =
      rendered.split("\n").mkString(pad)

    val suffix =
      if (children.isEmpty) ""
      else {
        val nonLast = children.init
        val nonLastPart =
          if (nonLast.isEmpty) ""
          else "\n|___" + nonLast.map(c => reindent(c.toString, "\n| ")).mkString("\n|___")
        nonLastPart + "\n|___" + reindent(children.last.toString, "\n ") + "\n."
      }
    value + suffix
  }
}
/**
 * A single labelled (or unlabelled) training/scoring example.
 *
 * @param features map from feature name to its categorical value
 * @param label    optional class label; `None` for unlabelled instances
 */
case class Instance(
  features: Map[String, String] = Map.empty[String, String],
  label: Option[String] = None
)
/**
 * Helper computations over a collection of [[Instance]]s used by the ID3
 * learner: label statistics, Shannon entropy, and per-feature entropy.
 *
 * Fixes over the original: unlabelled instances (label == None) are skipped
 * instead of throwing via `Option.get`; `mapValues` (a non-strict view since
 * Scala 2.13) is replaced with a strict map; `entropy` makes a single
 * counting pass instead of re-filtering the samples once per label.
 */
object DataSet {
  /** Distinct labels present among the labelled samples. */
  def labels(samples: Iterable[Instance]): Set[String] =
    samples.flatMap(_.label).toSet

  /** Number of labelled samples per label value. */
  def labelCount(samples: Iterable[Instance]): Map[String, Int] =
    samples
      .flatMap(_.label)
      .groupBy(identity)
      .map { case (label, occurrences) => label -> occurrences.size }

  /**
   * Shannon entropy (natural log) of the label distribution.
   * Returns 0.0 for an empty collection.
   */
  def entropy(samples: Iterable[Instance]): Double = {
    if (samples.isEmpty) 0.0
    else {
      val total = samples.size.toDouble
      labelCount(samples).values.map { count =>
        val p = count / total
        -p * Math.log(p)
      }.sum
    }
  }

  /** Distinct values taken by `feature` among samples that define it. */
  def featureValues(samples: Iterable[Instance], feature: String): Set[String] =
    samples.flatMap(_.features.get(feature)).toSet

  /**
   * Size-weighted sum of entropies of the partitions induced by `feature`
   * (samples missing the feature are excluded). NOTE: this is NOT normalized
   * by the sample count — it is proportional to conditional entropy, which is
   * sufficient for the argmin comparison performed by the tree builder.
   */
  def featureEntropy(samples: Iterable[Instance], feature: String): Double = {
    if (samples.isEmpty) 0.0
    else samples
      .filter(_.features.contains(feature))
      .groupBy(_.features(feature))
      .values
      .map(subset => subset.size.toDouble * entropy(subset))
      .sum
  }
}
/**
 * Immutable summary of a tree bucket: the split feature/value that routed
 * samples here, the sample count, per-label counts, and label entropy.
 */
case class BucketStats(
  feature: String,
  value: String,
  samples: Int,
  labelCount: Map[String, Int],
  entropy: Double
)
/**
 * The payload of a decision-tree node: the subset of samples routed here by
 * the condition `feature == value`, plus cached statistics over that subset.
 * The root bucket is constructed with null feature/value (no split applied).
 */
case class Bucket(
  feature: String,
  value: String,
  samples: Iterable[Instance],
  bucketStats: BucketStats
) {
  /** Convenience constructor deriving the statistics from the samples. */
  def this(feature: String, value: String, samples: Iterable[Instance]) = this(
    feature,
    value,
    samples,
    BucketStats(
      feature,
      value,
      samples.size,
      DataSet.labelCount(samples),
      DataSet.entropy(samples)
    )
  )

  // Render only the compact stats, not the full sample collection.
  override def toString(): String = bucketStats.toString
}
/**
 * A decision tree learned from `data` over the given `features` using ID3
 * (see companion object); the tree is built eagerly at construction time.
 *
 * @param data     labelled training instances
 * @param features feature names the learner may split on
 * @param maxDepth maximum tree depth (default: unbounded)
 */
class DecisionTree(
  data: Iterable[Instance],
  features: Set[String],
  maxDepth: Int = Int.MaxValue
) {
  // Root bucket carries all data and no split condition (null feature/value).
  val tree: Tree[Bucket] = Tree(new Bucket(null, null, data))
  DecisionTree.decisionTreeBuilder(tree, features, maxDepth)

  /**
   * Predicts a label for `input` by walking the tree: follow the child whose
   * (feature, value) condition matches the input; when at a leaf, or when no
   * child matches, return the majority label of the current node's samples.
   */
  def score(input: Instance, node: Tree[Bucket] = this.tree): String = {
    // Single pass instead of filter-then-head; None also covers leaf nodes.
    val matchingChild = node.children.find { child =>
      input.features.get(child.value.feature).contains(child.value.value)
    }
    matchingChild match {
      case Some(child) => score(input, child)
      case None        => node.value.bucketStats.labelCount.maxBy(_._2)._1
    }
  }
}
/**
 * Companion object: the recursive ID3 tree builder plus a small demo `App`.
 */
object DecisionTree extends App {
  /**
   * Grows the tree rooted at `node` in place (ID3): pick the unused feature
   * minimizing weighted label entropy of its partitions, create one child per
   * feature value, and recurse. Stops at `maxDepth`, when features are
   * exhausted, or when the node is already pure (entropy == 0).
   *
   * @param node           node whose bucket holds the samples to split
   * @param unusedFeatures features not yet used on the path from the root
   * @param maxDepth       maximum recursion depth
   * @param depth          current depth (root = 0)
   */
  def decisionTreeBuilder(
    node: Tree[Bucket],
    unusedFeatures: Set[String],
    maxDepth: Int = Int.MaxValue,
    depth: Int = 0
  ): Unit = {
    if (depth < maxDepth && unusedFeatures.nonEmpty) {
      val samples: Iterable[Instance] = node.value.samples
      // A pure node (all samples share one label) needs no further splitting.
      if (DataSet.entropy(samples) != 0.0) {
        val (bestFeature, _) = unusedFeatures
          .map(feature => (feature, DataSet.featureEntropy(samples, feature)))
          .minBy(_._2)
        val total = samples.size // hoisted: invariant across the loop below
        for (value <- DataSet.featureValues(samples, bestFeature)) {
          val subset = samples.filter(_.features.get(bestFeature).contains(value))
          // Skip degenerate splits: empty buckets and buckets equal to the parent.
          if (subset.nonEmpty && subset.size < total) {
            val child: Tree[Bucket] = Tree(new Bucket(bestFeature, value, subset))
            node.children.add(child)
            decisionTreeBuilder(child, unusedFeatures - bestFeature, maxDepth, depth + 1)
          }
        }
      }
    }
  }

  // Toy dataset; note the last two rows share features but disagree on the
  // label, so the data is deliberately not perfectly separable.
  val data = Set(
    Instance(Map("A" -> "a", "B" -> "a", "C" -> "a"), Some("f")),
    Instance(Map("A" -> "a", "B" -> "a", "C" -> "b"), Some("t")),
    Instance(Map("A" -> "a", "B" -> "b", "C" -> "a"), Some("f")),
    Instance(Map("A" -> "a", "B" -> "b", "C" -> "b"), Some("f")),
    Instance(Map("A" -> "b", "B" -> "a", "C" -> "a"), Some("t")),
    Instance(Map("A" -> "b", "B" -> "a", "C" -> "b"), Some("f")),
    Instance(Map("A" -> "b", "B" -> "b", "C" -> "a"), Some("t")),
    Instance(Map("A" -> "b", "B" -> "b", "C" -> "a"), Some("f"))
  )
  val tree: DecisionTree = new DecisionTree(data, Set("A", "B", "C"), 4)
  println(tree.tree)
  data.foreach { instance =>
    // Explicit tuple (the original relied on println's auto-tupling).
    println((instance, tree.score(instance)))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment