Skip to content

Instantly share code, notes, and snippets.

@shashir
Last active March 7, 2016 01:27
Show Gist options
  • Save shashir/64db06285ffa3cf2c28b to your computer and use it in GitHub Desktop.
Save shashir/64db06285ffa3cf2c28b to your computer and use it in GitHub Desktop.
Decision tree learning using ID3 algorithm
// https://github.com/shashir/cs7641/blob/master/lessons/ml_p1_lesson_01_decision_trees.md#id3-algorithm
import scala.collection.mutable.{Set => MSet}
/**
 * A rooted tree whose nodes carry a `value` and an unordered set of child subtrees.
 *
 * Note: this is a case class whose `children` is a mutable set, so a node's
 * `equals`/`hashCode` change whenever children are added to it. Populate a node's
 * subtree before inserting it into its parent's `children`; mutating a node that
 * already sits in a hash set leaves it filed under a stale hash.
 *
 * @param value    payload stored at this node
 * @param children child subtrees (iteration order is unspecified)
 */
case class Tree[T](
  value: T,
  children: MSet[Tree[T]] = MSet.empty[Tree[T]]
) {
  /**
   * Renders the tree as ASCII art: the node value, then each child introduced by
   * `|___`. Continuation lines of every child but the last are prefixed with `| `,
   * those of the last child with a single space, and a `.` line closes the child list.
   */
  override def toString(): String = {
    // Snapshot as a Seq: mapping over the Set itself would re-hash the rendered
    // strings, reordering them and silently merging children that render identically.
    val kids: Seq[Tree[T]] = children.toSeq
    // Interpolation replaces the deprecated `any + String` concatenation on a generic T.
    val header = s"$value"
    if (kids.isEmpty) {
      header
    } else {
      val middle = kids.init.map { child =>
        "\n|___" + child.toString.split("\n").mkString("\n| ")
      }.mkString
      val last = "\n|___" + kids.last.toString.split("\n").mkString("\n ") + "\n."
      header + middle + last
    }
  }
}
/**
 * One example: named categorical feature values plus an optional class label.
 * Training code requires `label` to be defined; scoring ignores it.
 */
case class Instance(
  features: Map[String, String] = Map[String, String](),
  label: Option[String] = None
)

/** Label and entropy statistics over collections of [[Instance]]s, used by the ID3 builder. */
object DataSet {
  /**
   * Label of a training sample.
   *
   * @throws NoSuchElementException if the sample is unlabeled (same type `Option.get`
   *                                would throw, but with a message naming the sample)
   */
  private def labelOf(sample: Instance): String =
    sample.label.getOrElse(
      throw new NoSuchElementException(s"Training sample has no label: $sample")
    )

  /** Distinct labels present in `samples` (every sample must be labeled). */
  def labels(samples: Iterable[Instance]): Set[String] =
    samples.map(labelOf).toSet

  /**
   * Number of samples per label.
   *
   * Built with a strict `map` rather than `mapValues`: in Scala 2.12 `mapValues` returns a
   * lazily re-evaluated, non-serializable view, and in 2.13 it does not return a `Map`.
   */
  def labelCount(samples: Iterable[Instance]): Map[String, Int] =
    samples.groupBy(labelOf).map { case (label, group) => label -> group.size }

  /**
   * Shannon entropy of the label distribution, in nats (natural log). The log base does
   * not affect which split minimizes it. Returns 0.0 for empty input.
   */
  def entropy(samples: Iterable[Instance]): Double =
    if (samples.isEmpty) {
      0.0
    } else {
      val total = samples.size.toDouble
      // `values` is an Iterable, not a Set, so equal terms are all kept in the sum.
      labelCount(samples).values.map { count =>
        val p = count / total
        -p * math.log(p)
      }.sum
    }

  /** Distinct values that `feature` takes among the samples that define it. */
  def featureValues(samples: Iterable[Instance], feature: String): Set[String] =
    samples.flatMap(_.features.get(feature)).toSet

  /**
   * Size-weighted entropy after splitting on `feature`: the sum over each value `v` of
   * `|S_v| * entropy(S_v)`. It is not divided by `|S|`, which is constant within a node.
   * Minimizing it is equivalent to maximizing information gain when every sample defines
   * `feature`.
   *
   * Samples lacking `feature` contribute nothing. NOTE(review): this can make
   * sparsely-present features look attractive.
   */
  def featureEntropy(samples: Iterable[Instance], feature: String): Double =
    samples
      .filter(_.features.contains(feature))
      .groupBy(_.features(feature))
      .values
      .map(group => group.size * entropy(group))
      .sum
}
/**
 * Summary of a tree node's contents.
 *
 * @param feature    feature the node was split on
 * @param value      value of `feature` shared by the node's samples
 * @param samples    number of samples in the node
 * @param labelCount number of samples per label
 * @param entropy    label entropy of the node's samples
 */
case class BucketStats(
  feature: String,
  value: String,
  samples: Int,
  labelCount: Map[String, Int],
  entropy: Double
)
/**
 * Payload of one decision-tree node: the samples that reached it plus precomputed
 * [[BucketStats]].
 *
 * `feature`/`value` name the split that produced the node; the root created by
 * [[DecisionTree]] passes `null` for both. As a case class, equality and hashing
 * cover `samples` too.
 */
case class Bucket(
  feature: String,
  value: String,
  samples: Iterable[Instance],
  bucketStats: BucketStats
) {
  /** Builds a bucket for `samples`, deriving its stats via [[DataSet]]. */
  def this(feature: String, value: String, samples: Iterable[Instance]) = this(
    feature,
    value,
    samples,
    BucketStats(
      feature = feature,
      value = value,
      samples = samples.size,
      labelCount = DataSet.labelCount(samples),
      entropy = DataSet.entropy(samples)
    )
  )

  /** Renders the node as its stats only; the samples themselves are omitted. */
  override def toString(): String = bucketStats.toString()
}
/**
 * An ID3 decision tree, trained eagerly when constructed.
 *
 * @param data     labeled training samples
 * @param features feature names the builder may split on
 * @param maxDepth depth limit passed to `DecisionTree.decisionTreeBuilder`
 */
class DecisionTree(
  data: Iterable[Instance],
  features: Set[String],
  maxDepth: Int = Int.MaxValue
) {
  /** Root of the learned tree; its bucket holds all training data and no split (`null` feature/value). */
  val tree: Tree[Bucket] = Tree(new Bucket(null, null, data))
  DecisionTree.decisionTreeBuilder(tree, features, maxDepth)

  /**
   * Predicts a label for `input`.
   *
   * Starting at `node`, it descends into the child whose (feature, value) matches
   * `input`. It stops at a leaf or when no child matches, and returns the majority
   * training label of the stopping node. Ties are broken by map iteration order.
   * `maxBy` throws on an empty label count, e.g. when `data` was empty.
   */
  def score(input: Instance, node: Tree[Bucket] = this.tree): String =
    // `find` stops at the first match instead of building a filtered set just to take its head.
    node.children.find(child => matches(input, child.value)) match {
      case Some(child) => score(input, child)
      case None        => majorityLabel(node)
    }

  /** True when `input` has `bucket.feature` set to exactly `bucket.value`. */
  private def matches(input: Instance, bucket: Bucket): Boolean =
    input.features.get(bucket.feature).contains(bucket.value)

  /** Most frequent training label among the samples stored at `node`. */
  private def majorityLabel(node: Tree[Bucket]): String =
    node.value.bucketStats.labelCount.maxBy(_._2)._1
}
/**
 * ID3 tree construction plus a small runnable demo.
 *
 * Because this object `extends App`, the demo vals below are initialized only when
 * `main` runs, not when `decisionTreeBuilder` is first referenced.
 */
object DecisionTree extends App {
  /**
   * Grows `node` in place using ID3.
   *
   * Among `unusedFeatures`, the one with the lowest [[DataSet.featureEntropy]] is chosen.
   * Ties are broken by feature name, so the result does not depend on set iteration order.
   * One child is created per value of that feature whose samples form a non-empty, proper
   * subset of the node's samples. If the chosen feature does not partition the node at all,
   * it is discarded and the next-best feature is tried at the same depth.
   *
   * Growth stops when `depth` reaches `maxDepth` (a negative `maxDepth` also stops
   * immediately), when no features remain, or when the node is pure (entropy 0).
   *
   * Children are attached to `node` only after their own subtrees are complete. [[Tree]]
   * is a case class over a mutable set, so its hashCode changes whenever children are
   * added, and mutating an element already stored in a hash set corrupts that set.
   *
   * @param node           node to grow; its bucket's samples are the ones split
   * @param unusedFeatures features not yet used on the path from the root to `node`
   * @param maxDepth       maximum depth of the resulting subtree, counted from the root
   * @param depth          depth of `node` (root = 0)
   */
  def decisionTreeBuilder(
    node: Tree[Bucket],
    unusedFeatures: Set[String],
    maxDepth: Int = Int.MaxValue,
    depth: Int = 0
  ): Unit = {
    val samples: Iterable[Instance] = node.value.samples
    val canGrow =
      depth < maxDepth && unusedFeatures.nonEmpty && DataSet.entropy(samples) != 0.0
    if (canGrow) {
      // Tuple ordering breaks entropy ties by feature name, making the choice deterministic.
      val bestFeature: String = unusedFeatures.minBy { feature =>
        (DataSet.featureEntropy(samples, feature), feature)
      }
      val remainingFeatures = unusedFeatures - bestFeature
      val children: Seq[Tree[Bucket]] = for {
        value <- DataSet.featureValues(samples, bestFeature).toSeq
        subset = samples.filter(_.features.get(bestFeature).contains(value))
        // Skip empty nodes and nodes that would just reproduce the parent (no progress).
        if subset.nonEmpty && subset.size < samples.size
      } yield Tree(new Bucket(bestFeature, value, subset))

      if (children.isEmpty) {
        // The chosen feature does not partition this node; fall back to the next-best one.
        decisionTreeBuilder(node, remainingFeatures, maxDepth, depth)
      } else {
        children.foreach { child =>
          decisionTreeBuilder(child, remainingFeatures, maxDepth, depth + 1)
          // Insert only now, so the child's hash is final while it sits in the set.
          node.children.add(child)
        }
      }
    }
  }

  // Demo: a small three-feature dataset; the last two rows conflict (same features, different labels).
  val data = Set(
    Instance(Map("A" -> "a", "B" -> "a", "C" -> "a"), Some("f")),
    Instance(Map("A" -> "a", "B" -> "a", "C" -> "b"), Some("t")),
    Instance(Map("A" -> "a", "B" -> "b", "C" -> "a"), Some("f")),
    Instance(Map("A" -> "a", "B" -> "b", "C" -> "b"), Some("f")),
    Instance(Map("A" -> "b", "B" -> "a", "C" -> "a"), Some("t")),
    Instance(Map("A" -> "b", "B" -> "a", "C" -> "b"), Some("f")),
    Instance(Map("A" -> "b", "B" -> "b", "C" -> "a"), Some("t")),
    Instance(Map("A" -> "b", "B" -> "b", "C" -> "a"), Some("f"))
  )
  val tree: DecisionTree = new DecisionTree(data, Set("A", "B", "C"), 4)
  println(tree.tree)
  data.foreach { instance: Instance =>
    // Explicit tuple: auto-tupling two arguments into println is deprecated.
    println((instance, tree.score(instance)))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment