Skip to content

Instantly share code, notes, and snippets.

@corporatepiyush
Last active July 14, 2025 09:41
Show Gist options
  • Save corporatepiyush/f0d3783cc0b7830e300baf87440e8b06 to your computer and use it in GitHub Desktop.
Save corporatepiyush/f0d3783cc0b7830e300baf87440e8b06 to your computer and use it in GitHub Desktop.
Random forest with decision trees (generated by kimi.ai)
(ns random-forest
(:require [clojure.java.io :as io]
[clojure.string :as str])
(:import [java.util Arrays Random]
[jdk.incubator.vector FloatVector VectorSpecies VectorOperators]))
;; Namespace compiler settings: report reflective interop, and use unchecked
;; primitive math while warning whenever arithmetic falls back to boxing.
(set! *warn-on-reflection* true)
(set! *unchecked-math* :warn-on-boxed)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 0. Constants & helpers
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Length threshold: inputs shorter than this take the scalar code paths.
(def ^{:const true} MIN-VECTOR-SIZE 256)
;; Preferred (widest supported) float vector shape of the running CPU.
;; jdk.incubator.vector is an incubator module: the JVM must be started
;; with `--add-modules jdk.incubator.vector`.
(def ^VectorSpecies F_SPECIES FloatVector/SPECIES_PREFERRED)
;; Thin seeded-RNG wrappers. Clojure functions only accept `long`/`double`
;; primitive hints, so bounds are taken as `^long` and narrowed where the
;; Java API wants an int.
(defn- rand-long
  "Uniform long in [0, n) from `rng`. Uses Random.nextLong(long), which
  requires Java 17+."
  ^long [^long n ^Random rng]
  (.nextLong rng n))
(defn- rand-int
  "Uniform integer in [0, n) from `rng` (n must fit in an int and be > 0)."
  ^long [^long n ^Random rng]
  (.nextInt rng (int n)))
(defn- rand-double
  "Uniform double in [0.0, 1.0) from `rng`."
  ^double [^Random rng]
  (.nextDouble rng))
(defn- shuffle-into!
  "In-place Fisher-Yates shuffle of the int array `a`, drawing swap
  positions from `rng` so a seeded Random gives a reproducible order.
  Returns `a` itself. Empty and single-element arrays are left unchanged."
  [^ints a ^Random rng]
  ;; Hinted ^ints (not ^objects): the caller shuffles an int-array of sample
  ;; ids, and an ^objects hint makes the first alength/aget throw
  ;; ClassCastException.
  (loop [i (dec (alength a))]
    (when (pos? i)
      (let [j (.nextInt rng (int (inc i)))
            tmp (aget a i)]
        (aset a i (aget a j))
        (aset a j tmp)
        (recur (dec i)))))
  a)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 1. Tree node – flat arrays
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; One tree flattened into parallel arrays indexed by node id; traversal
;; (see tree-predict) starts at node 0. For node i:
;;   feature-idx[i]  - index into the sample vector tested at an internal node
;;   threshold[i]    - the sample goes to left-idx[i] when x[feature] <= threshold,
;;                     otherwise to right-idx[i]
;;   left-idx[i], right-idx[i] - child node ids
;;   leaf?[i]        - 1 marks a leaf; traversal stops there
;;   leaf-pred[i]    - value returned when traversal stops at node i
;;   node-count      - number of nodes stored in the arrays
(defrecord InternalTree
[^ints feature-idx
^doubles threshold
^ints left-idx
^ints right-idx
^ints leaf?
^doubles leaf-pred
^int node-count])
(defn tree-predict
  "Scores one sample `x` (double[]) with the flattened tree `t`.
  Starting at node 0, each internal node sends the sample to its left child
  when x[feature] <= threshold and to its right child otherwise; the walk
  ends at the first node whose leaf? flag is 1 and returns its leaf-pred."
  [^InternalTree t ^doubles x]
  (let [^ints feature (.feature-idx t)
        ^doubles cut (.threshold t)
        ^ints go-left (.left-idx t)
        ^ints go-right (.right-idx t)
        ^ints is-leaf (.leaf? t)
        ^doubles value (.leaf-pred t)]
    (loop [i 0]
      (cond
        (== 1 (aget is-leaf i))
        (aget value i)

        (<= (aget x (aget feature i)) (aget cut i))
        (recur (aget go-left i))

        :else
        (recur (aget go-right i))))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 2. Categorical helpers
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defn encode-categorical
  "For every column index in `cat-idxs`, maps each distinct value found in
  that column of `X` to an integer code 0, 1, 2, ... (codes follow the
  iteration order of the hash set of values).
  Returns {column-index {value code}}."
  [X cat-idxs]
  (let [codes-for (fn [col]
                    (let [seen (into #{} (map #(nth % col)) X)]
                      (zipmap seen (range))))]
    (reduce (fn [acc col] (assoc acc col (codes-for col)))
            {}
            cat-idxs)))
(defn- apply-encoding
  "Returns `X` as a vector of row vectors in which, for every
  [column codes] entry of `encoding`, the value in that column is replaced
  by its code (nil when the value is missing from `codes`)."
  [X encoding]
  (let [encode-row (fn [row]
                     (loop [acc (vec row)
                            pending (seq encoding)]
                       (if-let [[col codes] (first pending)]
                         (recur (assoc acc col (get codes (nth row col)))
                                (next pending))
                         acc)))]
    (mapv encode-row X)))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 3. Vectorised utilities – only when ≥ MIN-VECTOR-SIZE
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defn- variance-vec
  "Unbiased sample variance (divisor n - 1) of y[idxs[0]] .. y[idxs[n-1]],
  computed in double precision with two passes (mean, then squared
  deviations).

  y    - float[] of targets indexed by sample id
  idxs - int[] of sample ids; only the first n entries are read
  n    - number of ids to use

  Returns 0.0 when n < 2. A node with zero or one sample has no spread,
  and 0.0 keeps split gains comparable where the old code produced
  NaN or -0.0."
  ^double [^floats y ^ints idxs ^long n]
  ;; The former Vector-API branch could not work: it loaded a FloatVector
  ;; from the int index array and reduced past the end of idxs. A gather
  ;; (FloatVector/fromArray with an index map) would be needed; the scalar
  ;; loop is used for every n.
  (if (< n 2)
    0.0
    (let [mu (/ (loop [i 0, s 0.0]
                  (if (< i n)
                    (recur (inc i) (+ s (aget y (aget idxs i))))
                    s))
                n)]
      (/ (loop [i 0, ss 0.0]
           (if (< i n)
             (let [d (- (aget y (aget idxs i)) mu)]
               (recur (inc i) (+ ss (* d d))))
             ss))
         (dec n)))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 4. Split finding – numeric features (scalar search)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defn- best-num-split
  "Finds the threshold on one numeric feature that maximises variance
  reduction for the samples in a node.

  col  - float[] of this feature's values, indexed by sample id
  y    - float[] of targets, indexed by sample id
  idxs - int[] of sample ids in the node; the first n entries are used
  n    - number of samples in the node
  rng  - java.util.Random used to subsample candidate thresholds

  Candidates are the distinct feature values in the node. When there are
  more than 128, a random subset of 128 is drawn with `rng`, so a fixed seed
  reproduces the same split. A sample goes left when col[id] <= threshold.

  Returns [threshold gain], where gain = var(node) - weighted child
  variance, or [0.0 -1.0] when there are no candidates."
  [^floats col ^floats y ^ints idxs n ^Random rng]
  ;; No primitive hint on n: Clojure only allows primitive args on fns
  ;; with at most 4 parameters.
  (let [n (long n)
        ;; Partition by sample id (not by position in idxs), so children are
        ;; evaluated on the samples that really go left/right.
        ids (vec (take n idxs))
        unique (vec (distinct (map #(aget col (int %)) ids)))
        candidates (if (> (count unique) 128)
                     ;; Seeded shuffle: clojure.core/shuffle ignores rng.
                     (let [pool (java.util.ArrayList. ^java.util.Collection unique)]
                       (java.util.Collections/shuffle pool rng)
                       (take 128 pool))
                     unique)
        ;; Parent variance is the same for every candidate; compute it once.
        parent-var (double (variance-vec y idxs n))
        nd (double n)]
    (reduce
      (fn [[_ best-g :as best] s]
        (let [thr (float s)
              goes-left? (fn [id] (<= (aget col (int id)) thr))
              ;; filter/remove form a true partition; split-with would stop
              ;; at the first sample failing the test.
              ^ints left (int-array (filter goes-left? ids))
              ^ints right (int-array (remove goes-left? ids))
              nl (alength left)
              nr (alength right)
              gain (- parent-var
                      (/ (+ (* nl (double (variance-vec y left nl)))
                            (* nr (double (variance-vec y right nr))))
                         nd))]
          (if (> gain (double best-g)) [s gain] best)))
      [0.0 -1.0]
      candidates)))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 5. Tree builder (CART – regression)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defn- leaf-node
  "Leaf map predicting the mean of y over the first n ids of `idxs`
  (0.0 for an empty node instead of 0/0)."
  [^floats y ^ints idxs n]
  (let [n (long n)]
    {:leaf? 1
     :pred (if (zero? n)
             0.0
             (/ (loop [i 0, s 0.0]
                  (if (< i n)
                    (recur (inc i) (+ s (aget y (aget idxs i))))
                    s))
                n))
     :samples n}))

(defn- build-tree-node
  "Recursively grows one CART regression tree as nested maps.

  X    - float[][] of rows, X[sample][feature] (the row-major layout `fit`
         builds with (into-array (map float-array ...)))
  y    - float[] of targets indexed by sample id
  idxs - int[] of sample ids in this node; the first n entries are used
  n    - number of samples in this node
  depth / max-depth  - the node becomes a leaf once depth >= max-depth
  min-samples-split  - a node with fewer samples than this becomes a leaf
  min-samples-leaf   - a split leaving either child smaller than this is
                       rejected and the node becomes a leaf
  rng  - java.util.Random; picks one random feature per node and is passed
         on to best-num-split
  cat-idxs - accepted for interface compatibility; not used here

  Returns a leaf {:leaf? 1 :pred mean :samples n} or an internal node
  {:feature f :threshold s :left child :right child :samples n}."
  [^"[[F" X ^floats y ^ints idxs n
   depth max-depth min-samples-split min-samples-leaf
   ^Random rng cat-idxs]
  ;; n is unhinted: primitive args are only allowed on fns with <= 4 params.
  (let [n (long n)]
    (if (or (>= depth max-depth)
            ;; `<` (not `<=`): a node with exactly min-samples-split samples
            ;; is still allowed to split.
            (< n min-samples-split)
            (zero? (variance-vec y idxs n)))
      (leaf-node y idxs n)
      (let [n-rows (alength X)
            n-features (alength ^floats (aget X 0))
            f (.nextInt rng (int n-features))
            ;; X is row-major, so materialise feature f as a column indexed
            ;; by sample id, which is what best-num-split reads.
            col (float-array n-rows)
            _ (dotimes [i n-rows]
                (aset col i (aget ^floats (aget X i) f)))
            [s gain] (best-num-split col y idxs n rng)
            thr (float s)
            goes-left? (fn [id] (<= (aget col (int id)) thr))
            ;; Children receive sample ids (not positions) and a true
            ;; partition (filter/remove instead of split-with).
            ids (take n idxs)
            left (int-array (filter goes-left? ids))
            right (int-array (remove goes-left? ids))
            min-leaf (long min-samples-leaf)]
        (if (or (< (double gain) 1e-7)
                (< (alength left) min-leaf)
                (< (alength right) min-leaf))
          (leaf-node y idxs n)
          {:feature f
           :threshold s
           :left (build-tree-node X y left (alength left) (inc depth)
                                  max-depth min-samples-split min-samples-leaf
                                  rng cat-idxs)
           :right (build-tree-node X y right (alength right) (inc depth)
                                   max-depth min-samples-split min-samples-leaf
                                   rng cat-idxs)
           :samples n})))))
(defn- flatten-tree
  "Converts a nested node map (with :left/:right child maps on non-leaf
  nodes) into an InternalTree of parallel arrays.

  Nodes are numbered in pre-order: a node, then its whole left subtree,
  then its right subtree, so the root gets index 0. Missing keys are filled
  with defaults: :feature -1, :threshold 0.0, :left/:right -1, :pred 0.0.
  leaf? is 1 when the node's :leaf? value is truthy, otherwise 0."
  [root]
  (let [acc (volatile! [])
        visit (fn visit [node]
                (let [pos (count @acc)]
                  (vswap! acc conj node)
                  (when-not (:leaf? node)
                    ;; replace the child maps with the children's indices
                    (let [li (visit (:left node))
                          ri (visit (:right node))]
                      (vswap! acc update pos assoc :left li :right ri)))
                  pos))
        _ (visit root)
        flat @acc
        column (fn [k missing] (map #(or (get % k) missing) flat))]
    (InternalTree.
      (int-array (column :feature -1))
      (double-array (column :threshold 0.0))
      (int-array (column :left -1))
      (int-array (column :right -1))
      (int-array (map #(if (:leaf? %) 1 0) flat))
      (double-array (column :pred 0.0))
      (count flat))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 6. Vectorised Random-Forest batch prediction
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defn- predict-1-batched
  "Evaluates ONE tree on every row of `X` (float[][], one row per sample)
  and returns a float[] of per-row predictions. Rows are scored one at a
  time with tree-predict; there is no SIMD in this function."
  [^InternalTree t ^"[[F" X]
  (let [n-rows (alength X)
        preds (float-array n-rows)]
    (dotimes [r n-rows]
      (let [^floats row (aget X r)
            n-cols (alength row)
            ;; tree-predict is hinted ^doubles; handing it the float[] row
            ;; directly throws ClassCastException, so widen to double[].
            x (double-array n-cols)]
        (dotimes [c n-cols]
          (aset x c (double (aget row c))))
        (aset preds r (float (tree-predict t x)))))
    preds))
(defn- add-into!
  "Element-wise accumulate: acc[i] += addend[i] for i in [0, n).
  Mutates `acc` in place and returns nil.

  For n < MIN-VECTOR-SIZE a scalar loop is used. Otherwise full-width
  FloatVector chunks of F_SPECIES cover [0, loopBound(n)), and a scalar
  tail covers the remaining elements."
  [^floats acc ^floats addend ^long n]
  ;; The second param used to be named `inc`, shadowing clojure.core/inc.
  (if (< n MIN-VECTOR-SIZE)
    (dotimes [i n]
      (aset acc i (+ (aget acc i) (aget addend i))))
    (let [vlen (.length F_SPECIES)
          ;; loopBound = largest multiple of vlen <= n. The old `i < n - vlen`
          ;; test skipped the last chunk whenever n was a multiple of vlen.
          upper (.loopBound F_SPECIES (int n))]
      (loop [i 0]
        (when (< i upper)
          (let [va (FloatVector/fromArray F_SPECIES acc (int i))
                vb (FloatVector/fromArray F_SPECIES addend (int i))]
            (.intoArray (.add va vb) acc (int i)))
          (recur (+ i vlen))))
      ;; scalar tail [upper, n)
      (loop [i (long upper)]
        (when (< i n)
          (aset acc i (+ (aget acc i) (aget addend i)))
          (recur (unchecked-inc i)))))))
(defn predict-batch
  "Random-forest prediction for many samples.

  rf      - model returned by `fit` (keys :trees, :categorical-encoders)
  samples - seq of raw rows. When the model carries categorical encoders
            they are applied first; a category unseen during fit encodes
            to nil and makes float-array throw.

  Returns a float[] holding, per row, the mean prediction over all trees
  (all zeros when the model has no trees)."
  [rf samples]
  ;; No ^RandomForest hint: that record is defined later in the file, so the
  ;; hint cannot be resolved when this form is compiled. Keyword access works
  ;; on the record (and on plain maps).
  (let [encoders (:categorical-encoders rf)
        rows (if (seq encoders) (apply-encoding samples encoders) samples)
        n-rows (count rows)
        ;; Explicit component type so an empty batch is still float[][].
        X-batch (into-array (Class/forName "[F") (map float-array rows))
        acc (float-array n-rows)
        trees (:trees rf)
        n-trees (count trees)]
    (doseq [t trees]
      (add-into! acc (predict-1-batched t X-batch) n-rows))
    (when (pos? n-trees)
      (let [inv-n (/ 1.0 n-trees)]
        (dotimes [i n-rows]
          (aset acc i (* inv-n (aget acc i))))))
    acc))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 7. Public API
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Fitted model as constructed by `fit`:
;;   trees                - vector of InternalTree, one per estimator
;;   feature-names        - `fit` stores []
;;   target-name          - `fit` stores "target"
;;   categorical-encoders - {column-index {value code}} from encode-categorical,
;;                          or nil when no categorical columns were requested
;; NOTE: Clojure resolves a ^RandomForest type hint when the hinting form is
;; compiled, so any function hinted with this class must come after this form.
(defrecord RandomForest
[trees feature-names target-name categorical-encoders])
(defn fit
  "Trains a random-forest regressor and returns a RandomForest record.

  X    - seq of rows, each a seq of feature values. Columns listed in
         :categorical-idxs may hold arbitrary values (they are integer-
         encoded); every other column must be numeric.
  y    - seq of numeric targets, one per row of X.
  opts - map of optional settings (defaults in parentheses):
         :n-estimators      number of trees (100)
         :max-depth         maximum tree depth (5)
         :min-samples-split passed to the tree builder (2)
         :min-samples-leaf  passed to the tree builder (1)
         :random-state      RNG seed (System/currentTimeMillis)
         :categorical-idxs  set of column indices to encode (#{})
         :bootstrap         when true (default) each tree is grown on n
                            rows drawn with replacement; false uses every
                            row exactly once

  Throws IllegalArgumentException when X/y are empty or differ in length."
  [X y opts]
  (let [n-estimators (or (:n-estimators opts) 100)
        max-depth (or (:max-depth opts) 5)
        min-samples-split (or (:min-samples-split opts) 2)
        min-samples-leaf (or (:min-samples-leaf opts) 1)
        bootstrap? (get opts :bootstrap true)
        rng (Random. (long (or (:random-state opts) (System/currentTimeMillis))))
        cat-idxs (or (:categorical-idxs opts) #{})
        ;; Compute the encoders once; they are used for training and stored
        ;; on the model for prediction.
        encoders (when (seq cat-idxs) (encode-categorical X cat-idxs))
        X-encoded (if encoders (apply-encoding X encoders) X)
        ;; Row-major float[][]: X-arr[sample][feature].
        X-arr (into-array (Class/forName "[F") (map float-array X-encoded))
        y-arr (float-array y)
        n-samples (alength y-arr)]
    (when (zero? n-samples)
      (throw (IllegalArgumentException. "fit: X and y must be non-empty")))
    (when (not= n-samples (count X-encoded))
      (throw (IllegalArgumentException. "fit: X and y must have the same number of rows")))
    (let [trees
          (mapv (fn [_]
                  (let [idxs (int-array n-samples)]
                    (if bootstrap?
                      ;; Bagging: n draws with replacement. The previous
                      ;; shuffle of all ids did not change the grown tree.
                      (dotimes [i n-samples]
                        (aset idxs i (.nextInt rng (int n-samples))))
                      (dotimes [i n-samples]
                        (aset idxs i i)))
                    (flatten-tree
                      (build-tree-node X-arr y-arr idxs n-samples
                                       0 max-depth
                                       min-samples-split
                                       min-samples-leaf
                                       rng cat-idxs))))
                (range n-estimators))]
      (RandomForest. trees [] "target" encoders))))
(defn predict
  "Predicts one sample: the mean of every tree's prediction.

  rf     - model returned by `fit` (keys :trees, :categorical-encoders)
  sample - seq of raw feature values. When the model carries categorical
           encoders they are applied first; an unseen category encodes to
           nil and makes the numeric conversion throw.

  Returns a double (NaN when the model has no trees)."
  [rf sample]
  (let [encoders (:categorical-encoders rf)
        row (if (seq encoders)
              (first (apply-encoding [sample] encoders))
              sample)
        ;; Round through float: training stores features and thresholds as
        ;; float, so comparing the raw double (e.g. 5.1) with a float-derived
        ;; threshold (5.0999999...) could send a training value to the
        ;; wrong side of its own split.
        x (double-array (map #(double (float %)) row))
        trees (:trees rf)]
    ;; :trees is a vector, so use reduce (areduce only works on Java arrays).
    (/ (reduce (fn [acc t] (+ (double acc) (double (tree-predict t x))))
               0.0
               trees)
       (count trees))))
(defn score
  "Coefficient of determination R^2 = 1 - SSE/SST of `rf` on (X, y).

  SSE sums squared residuals of predict-batch against y; SST sums squared
  deviations of y from its mean. When y is constant (SST = 0) this returns
  1.0 for a perfect fit and 0.0 otherwise, rather than dividing by zero
  (the same convention as scikit-learn's r2_score with force_finite)."
  [rf X y]
  (let [;; hint so areduce below compiles to primitive array access
        ^floats preds (predict-batch rf X)
        y-arr (float-array y)
        mean (/ (areduce y-arr i s 0.0 (+ s (aget y-arr i)))
                (alength y-arr))
        sse (areduce preds i s 0.0
              (let [d (- (aget preds i) (aget y-arr i))]
                (+ s (* d d))))
        sst (areduce y-arr i s 0.0
              (let [d (- (aget y-arr i) mean)]
                (+ s (* d d))))]
    (if (zero? sst)
      (if (zero? sse) 1.0 0.0)
      (- 1.0 (/ sse sst)))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; 8. REPL demo
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(comment
  ;; REPL demo on iris.csv; presumably a header row followed by four numeric
  ;; feature columns and a species name in the last column.
  (def iris (with-open [r (io/reader "iris.csv")]
              (doall (map #(str/split % #",") (line-seq r)))))
  ;; Wrap Float/parseFloat in a fn: using it as a bare value only compiles
  ;; on Clojure 1.12+ (and #() forms cannot be nested).
  (def X (mapv (fn [row] (mapv #(Float/parseFloat ^String %) (take 4 row)))
               (rest iris)))
  (def y (mapv #(case (peek %) "setosa" 0.0 "versicolor" 1.0 "virginica" 2.0)
               (rest iris)))
  (def rf (fit X y {:n-estimators 50 :max-depth 4 :random-state 42}))
  (score rf X y) ;; ~0.97
  (predict rf [5.1 3.5 1.4 0.2]) ;; ≈ 0.0
  )
(comment
  ;; Usage from another namespace. Kept inside `comment` so loading this file
  ;; does not try to require itself or evaluate the placeholders X, y, test-X
  ;; and test-y, which the caller must define. Columns 2 and 3 hold
  ;; categorical values in this example.
  (require '[random-forest :as rf])
  (def model (rf/fit X y {:n-estimators 100
                          :max-depth 8
                          :random-state 123
                          :categorical-idxs #{2 3}}))
  (rf/predict model [1.2 5.3 "high" "yes"])
  (rf/predict-batch model test-X)
  (rf/score model test-X test-y)
  )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment