Last active
July 14, 2025 09:41
-
-
Save corporatepiyush/f0d3783cc0b7830e300baf87440e8b06 to your computer and use it in GitHub Desktop.
Random forest built from decision trees (generated by kimi.ai)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns random-forest | |
(:require [clojure.java.io :as io] | |
[clojure.string :as str]) | |
(:import [java.util Arrays Random] | |
[jdk.incubator.vector FloatVector VectorSpecies VectorOperators])) | |
;; Make the compiler report interop calls it can only resolve via reflection.
(set! *warn-on-reflection* true) | |
;; Compile arithmetic as unchecked and warn wherever it falls back to boxed math.
(set! *unchecked-math* :warn-on-boxed) | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 0. Constants & helpers | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;; Element-count cutoff: helpers in this file that compare n against it take
;; their scalar branch when n is smaller.
(def ^:const MIN-VECTOR-SIZE 256) | |
;; Float vector species the running JVM reports as preferred; used for every
;; vector load in this file.
(def ^VectorSpecies F_SPECIES (FloatVector/SPECIES_PREFERRED)) | |
;; Thin wrappers over java.util.Random with argument order (bound, rng).
;; Clojure accepts only ^long and ^double as primitive hints on fn parameters
;; and return values, so integer bounds are taken as longs and narrowed at the
;; interop call.
(defn- rand-long
  "Returns a uniformly distributed long in [0, n) drawn from rng."
  ^long [^long n ^Random rng]
  (.nextLong rng n))

(defn- rand-int
  "Returns a uniformly distributed integer in [0, n) drawn from rng.
  The value is returned as a long; n must fit in a Java int.
  Note: this name shadows clojure.core/rand-int inside this namespace."
  ^long [^long n ^Random rng]
  (long (.nextInt rng (int n))))

(defn- rand-double
  "Returns a uniformly distributed double in [0.0, 1.0) drawn from rng."
  ^double [^Random rng]
  (.nextDouble rng))
(defn- shuffle-into!
  "In-place Fisher-Yates shuffle of the int array a, driven by rng.
  Returns the same array instance. Arrays of length 0 or 1 are returned
  unchanged.

  The array is hinted as ints because the caller in this file passes an
  int-array of sample indices; an Object[] hint would fail with a
  ClassCastException on that input."
  [^ints a ^Random rng]
  (loop [i (dec (alength a))]
    (when (pos? i)
      ;; pick j uniformly from [0, i] and swap positions i and j
      (let [j   (.nextInt rng (int (inc i)))
            tmp (aget a i)]
        (aset a i (aget a j))
        (aset a j tmp)
        (recur (dec i)))))
  a)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 1. Tree node – flat arrays | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;; Flattened decision tree: node i is described by slot i of every array.
;; Prediction starts at node 0 (see tree-predict).
(defrecord InternalTree | |
;; index into the sample vector of the feature tested at each node
;; (flatten-tree stores -1 where a node has no :feature)
[^ints feature-idx | |
;; split value; a sample goes left when its feature value is <= threshold
^doubles threshold | |
;; node index of the left child (-1 where a node has no :left)
^ints left-idx | |
;; node index of the right child (-1 where a node has no :right)
^ints right-idx | |
;; 1 marks a leaf; flatten-tree writes 0 for every other node
^ints leaf? | |
;; value returned when prediction stops at this node (0.0 where a node has no :pred)
^doubles leaf-pred | |
;; number of nodes flatten-tree collected for this tree
^int node-count]) | |
(defn tree-predict
  "Scalar prediction for one sample.

  Starts at node 0 of the flattened tree t and follows child links until a
  node flagged as a leaf is reached, then returns that node's stored
  prediction. At each internal node the sample value at the node's feature
  index is compared with the node's threshold: <= goes left, otherwise right."
  [^InternalTree t ^doubles x]
  (let [^ints    split-feature (.feature-idx t)
        ^doubles split-value   (.threshold t)
        ^ints    left-child    (.left-idx t)
        ^ints    right-child   (.right-idx t)
        ^ints    leaf-flag     (.leaf? t)
        ^doubles leaf-value    (.leaf-pred t)]
    (loop [cur 0]
      (cond
        (== 1 (aget leaf-flag cur))
        (aget leaf-value cur)

        (<= (aget x (aget split-feature cur)) (aget split-value cur))
        (recur (aget left-child cur))

        :else
        (recur (aget right-child cur))))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 2. Categorical helpers | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(defn encode-categorical
  "Builds an integer code table for each categorical column.

  X is a collection of rows and cat-idxs a collection of column positions.
  Returns a map of column position -> {distinct value -> integer code}, where
  the codes are 0, 1, 2, ... assigned in the iteration order of the set of
  distinct values found in that column."
  [X cat-idxs]
  (reduce (fn [encoders col]
            (let [levels (into #{} (map (fn [row] (nth row col)) X))]
              (assoc encoders col (zipmap levels (range)))))
          {}
          cat-idxs))
(defn- apply-encoding
  "Replaces categorical cells with their codes.

  encoding maps a column position to a lookup table (as produced by
  encode-categorical). Every row of X is returned as a vector in which each
  encoded column holds the table entry for the row's original value; values
  missing from the table become nil. Returns a vector of row vectors."
  [X encoding]
  (let [encode-row
        (fn [row]
          (loop [out     (vec row)
                 pending (seq encoding)]
            (if-let [[[col lookup] & more] pending]
              (recur (assoc out col (get lookup (nth row col))) more)
              out)))]
    (into [] (map encode-row) X)))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 3. Vectorised utilities – only when ≥ MIN-VECTOR-SIZE | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(defn- variance-vec
  "Sample variance (n - 1 denominator) of the targets y[idxs[0]] .. y[idxs[n-1]].

  Parameters:
    y    - float array of target values
    idxs - int array of positions into y; only the first n entries are read
    n    - number of entries of idxs to use

  Returns a double. Fewer than two samples have no spread, so 0.0 is returned
  for n < 2 instead of the NaN / -0.0 a 0/0 or 0/-1 division would produce.

  Both passes (mean, then squared deviations) use gathered vector loads when
  n >= MIN-VECTOR-SIZE and a plain scalar loop otherwise. The scalar loop also
  handles the tail that does not fill a whole vector. Lane sums are formed in
  float by the vector unit and accumulated in double."
  [^floats y ^ints idxs ^long n]
  (if (< n 2)
    0.0
    (let [vlen  (long (.length F_SPECIES))
          ;; largest multiple of the lane count <= n; 0 disables the vector path
          bound (if (>= n (long MIN-VECTOR-SIZE))
                  (- n (rem n vlen))
                  0)
          total (loop [i 0 acc 0.0]
                  (cond
                    (< i bound)
                    ;; gather: lane k loads y[0 + idxs[i + k]]
                    (let [v (FloatVector/fromArray F_SPECIES y (int 0) idxs (int i))]
                      (recur (+ i vlen)
                             (+ acc (double (.reduceLanes v VectorOperators/ADD)))))

                    (< i n)
                    (recur (+ i 1) (+ acc (double (aget y (aget idxs i)))))

                    :else acc))
          mu    (/ total (double n))
          mu-f  (float mu)
          ss    (loop [i 0 acc 0.0]
                  (cond
                    (< i bound)
                    (let [d (.sub (FloatVector/fromArray F_SPECIES y (int 0) idxs (int i))
                                  mu-f)]
                      (recur (+ i vlen)
                             (+ acc (double (.reduceLanes (.mul d d)
                                                          VectorOperators/ADD)))))

                    (< i n)
                    (let [d (- (double (aget y (aget idxs i))) mu)]
                      (recur (+ i 1) (+ acc (* d d))))

                    :else acc))]
      (/ ss (double (dec n))))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 4. Split finding – numeric (vectorised path) | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(defn- best-num-split
  "Finds the threshold on one numeric feature that gives the largest
  variance reduction.

  Parameters:
    col  - float array holding the feature value of every sample, indexed by
           sample id
    y    - float array of targets, indexed by sample id
    idxs - int array of the sample ids that reached this node; the first n
           entries are used
    n    - number of samples at this node
    rng  - source of randomness used to subsample candidate thresholds

  Returns [threshold gain]. Candidates are the distinct feature values at this
  node (a random subset of 128 when there are more). A candidate that would
  leave one side empty is not a split and is skipped. When no candidate
  qualifies the initial [0.0 -1.0] is returned, which callers treat as
  no useful split.

  Samples are partitioned with filter/remove over the sample ids themselves,
  so a split is a true partition and the child arrays hold sample ids rather
  than positions. Candidate subsampling uses rng so results are reproducible
  for a fixed seed."
  [^floats col ^floats y ^ints idxs ^long n ^Random rng]
  (let [members    (vec (take n idxs))
        unique     (vec (distinct (map #(aget col (int %)) members)))
        candidates (if (> (count unique) 128)
                     (let [pool (java.util.ArrayList. ^java.util.Collection unique)]
                       (java.util.Collections/shuffle pool rng)
                       (vec (take 128 pool)))
                     unique)
        ;; the parent variance does not depend on the candidate: compute it once
        parent-var (double (variance-vec y idxs n))
        ;; a side with fewer than two samples has no spread
        side-var   (fn [^ints side]
                     (let [cnt (alength side)]
                       (if (< cnt 2)
                         0.0
                         (double (variance-vec y side cnt)))))]
    (reduce
     (fn [[_ best-g :as best] s]
       (let [thr   (double s)
             left? (fn [sample] (<= (double (aget col (int sample))) thr))
             left  (int-array (filter left? members))
             right (int-array (remove left? members))
             nl    (alength left)
             nr    (alength right)]
         (if (or (zero? nl) (zero? nr))
           best
           (let [gain (- parent-var
                         (+ (* (/ (double nl) (double n)) (double (side-var left)))
                            (* (/ (double nr) (double n)) (double (side-var right)))))]
             (if (> gain (double best-g))
               [s gain]
               best)))))
     [0.0 -1.0]
     candidates)))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 5. Tree builder (CART – regression) | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(defn- build-tree-node
  "Recursively grows a regression tree and returns it as nested maps.

  Parameters:
    X                 - float[][] of samples, one row per sample (the layout
                        fit builds with into-array over float-array rows)
    y                 - float array of targets, indexed by sample id
    idxs              - int array of the sample ids at this node
    n                 - number of entries of idxs in use
    depth, max-depth  - current depth and the depth at which growth stops
    min-samples-split - a node with fewer samples than this becomes a leaf
    min-samples-leaf  - a split leaving fewer samples than this on either side
                        is rejected (values below 1 are treated as 1)
    rng               - source of randomness for the feature choice
    cat-idxs          - passed through to the children unchanged

  Returns either a leaf {:leaf? 1 :pred mean-target :samples n} or an internal
  node {:feature f :threshold s :left node :right node :samples n}.

  One feature is drawn at random per node, from the number of columns in a
  row. The chosen column is copied into a scratch array indexed by sample id,
  which is the layout best-num-split reads. Children receive arrays of sample
  ids produced by filter/remove, so both sides together always cover exactly
  the samples of the parent."
  [^"[[F" X ^floats y ^ints idxs ^long n
   depth max-depth min-samples-split min-samples-leaf
   ^Random rng cat-idxs]
  (let [mean-leaf (fn []
                    {:leaf?   1
                     :pred    (/ (areduce idxs i s 0.0
                                          (+ s (aget y (aget idxs i))))
                                 n)
                     :samples n})
        min-leaf  (max 1 (long min-samples-leaf))]
    (if (or (>= (long depth) (long max-depth))
            (< n 2)
            (< n (long min-samples-split))
            (zero? (double (variance-vec y idxs n))))
      (mean-leaf)
      (let [n-features (alength ^floats (aget X 0))
            f          (.nextInt rng n-features)
            col        (float-array (alength X))]
        ;; only the rows present at this node need their value copied
        (dotimes [i n]
          (let [row (aget idxs i)]
            (aset col row (aget ^floats (aget X row) f))))
        (let [[s gain] (best-num-split col y idxs n rng)]
          (if (< (double gain) 1e-7)
            (mean-leaf)
            (let [thr     (double s)
                  left?   (fn [sample] (<= (double (aget col (int sample))) thr))
                  members (vec (take n idxs))
                  left    (int-array (filter left? members))
                  right   (int-array (remove left? members))]
              (if (or (< (alength left) min-leaf)
                      (< (alength right) min-leaf))
                (mean-leaf)
                {:feature   f
                 :threshold s
                 :left      (build-tree-node X y left (alength left)
                                             (inc (long depth))
                                             max-depth min-samples-split
                                             min-samples-leaf rng cat-idxs)
                 :right     (build-tree-node X y right (alength right)
                                             (inc (long depth))
                                             max-depth min-samples-split
                                             min-samples-leaf rng cat-idxs)
                 :samples   n}))))))))
(defn- flatten-tree
  "Converts the nested-map tree rooted at root into an InternalTree.

  Nodes are numbered in pre-order (a node, then its whole left subtree, then
  its right subtree), so the root is node 0. In the flattened form each
  internal node's :left and :right hold the node numbers of its children.
  Slots that do not apply to a node are filled with -1 (feature, children)
  or 0.0 (threshold, prediction)."
  [root]
  (letfn [(emit [flat node]
            ;; appends node and its subtree to flat; returns [flat' slot-of-node]
            (let [slot (count flat)
                  flat (conj flat node)]
              (if (:leaf? node)
                [flat slot]
                (let [[flat l] (emit flat (:left node))
                      [flat r] (emit flat (:right node))]
                  [(-> flat
                       (assoc-in [slot :left] l)
                       (assoc-in [slot :right] r))
                   slot]))))]
    (let [[flat _] (emit [] root)
          column   (fn [k fallback] (mapv #(or (get % k) fallback) flat))]
      (InternalTree.
       (int-array (column :feature -1))
       (double-array (column :threshold 0.0))
       (int-array (column :left -1))
       (int-array (column :right -1))
       (int-array (mapv #(if (:leaf? %) 1 0) flat))
       (double-array (column :pred 0.0))
       (count flat)))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 6. Vectorised Random-Forest batch prediction | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(defn- predict-1-batched
  "Evaluates ONE tree on every row of X.

  X is a float[][] with one sample per row. Returns a float array with one
  prediction per row. Rows are walked one at a time with tree-predict; there
  is no vector code in this function.

  tree-predict is hinted for double[] input, so each float row is widened
  into a reusable double[] scratch buffer before the call."
  [^InternalTree t ^"[[F" X]
  (let [n-rows (alength X)
        preds  (float-array n-rows)]
    (when (pos? n-rows)
      (let [width   (alength ^floats (aget X 0))
            scratch (double-array width)]
        (dotimes [r n-rows]
          (let [^floats row (aget X r)
                ;; fall back to a fresh buffer if a row has a different width
                ^doubles x  (if (== (alength row) width)
                              scratch
                              (double-array (alength row)))]
            (dotimes [j (alength row)]
              (aset x j (double (aget row j))))
            (aset preds r (float (tree-predict t x)))))))
    preds))
(defn- add-into!
  "Element-wise in-place addition: acc[i] += inc[i] for i in [0, n).

  Parameters:
    acc - float array that receives the sums
    inc - float array of values to add (note: this parameter name shadows
          clojure.core/inc inside the body, so increments are written (+ i 1))
    n   - number of leading elements to process

  Returns acc.

  Below MIN-VECTOR-SIZE a scalar loop is used. Otherwise whole vectors are
  processed up to the largest multiple of the lane count that is <= n, and a
  scalar loop finishes the remainder. This covers every element, including
  the last full vector when n is an exact multiple of the lane count."
  [^floats acc ^floats inc ^long n]
  (if (< n (long MIN-VECTOR-SIZE))
    (dotimes [i n]
      (aset acc i (float (+ (aget acc i) (aget inc i)))))
    (let [vlen  (long (.length F_SPECIES))
          bound (- n (rem n vlen))]
      (loop [i 0]
        (when (< i bound)
          (.intoArray (.add (FloatVector/fromArray F_SPECIES inc (int i))
                            (FloatVector/fromArray F_SPECIES acc (int i)))
                      acc (int i))
          (recur (+ i vlen))))
      ;; scalar tail: the n mod vlen elements after the last full vector
      (loop [i bound]
        (when (< i n)
          (aset acc i (float (+ (aget acc i) (aget inc i))))
          (recur (+ i 1))))))
  acc)
(defn predict-batch
  "Random-forest prediction for a batch of samples.

  Parameters:
    rf      - a fitted forest; its :trees and :categorical-encoders are read
              by keyword, so no class hint is needed here (the RandomForest
              record is defined further down the file and cannot be resolved
              as a type hint at this point)
    samples - collection of rows; when the forest carries categorical
              encoders, the encoded columns must hold the raw category values

  Returns a float array holding, for each row, the mean of the per-tree
  predictions. A category that was not seen during fitting encodes to nil and
  makes the float conversion fail."
  [rf samples]
  (let [encoders (:categorical-encoders rf)
        rows     (if (seq encoders)
                   (apply-encoding samples encoders)
                   samples)
        n-rows   (count rows)
        ;; give the outer array an explicit float[] element type so that an
        ;; empty batch still produces a float[][]
        X-batch  (into-array (Class/forName "[F") (map float-array rows))
        acc      (float-array n-rows)
        trees    (:trees rf)]
    (doseq [t trees]
      (add-into! acc (predict-1-batched t X-batch) n-rows))
    (let [inv-n (/ 1.0 (double (count trees)))]
      (dotimes [i n-rows]
        (aset acc i (float (* inv-n (aget acc i))))))
    acc))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 7. Public API | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;; Fitted forest as returned by fit.
;;   trees                - vector of InternalTree, one per estimator
;;   feature-names        - fit stores an empty vector here
;;   target-name          - fit stores the string "target" here
;;   categorical-encoders - map produced by encode-categorical, or nil when no
;;                          categorical column indices were given to fit
(defrecord RandomForest | |
[trees feature-names target-name categorical-encoders]) | |
(defn fit
  "Trains a regression random forest.

  Parameters:
    X    - collection of rows (each a sequence of feature values)
    y    - collection of numeric targets, one per row of X
    opts - map of options:
             :n-estimators      number of trees            (default 100)
             :max-depth         maximum tree depth         (default 5)
             :min-samples-split passed to the tree builder (default 2)
             :min-samples-leaf  passed to the tree builder (default 1)
             :random-state      seed for the random generator
                                (default: current time in milliseconds)
             :categorical-idxs  column positions to integer-encode
                                (default: none)

  Returns a RandomForest record holding the trees and, when categorical
  columns were given, the encoders derived from X.

  Each tree is trained on a bootstrap sample: n-samples row ids drawn with
  replacement from the seeded generator. Giving every tree the same rows in a
  different order would make the row sampling a no-op.

  Throws IllegalArgumentException when X is empty or X and y differ in length."
  [X y opts]
  (when (empty? X)
    (throw (IllegalArgumentException. "fit: X must contain at least one sample")))
  (when (not= (count X) (count y))
    (throw (IllegalArgumentException.
            "fit: X and y must contain the same number of samples")))
  (let [n-estimators      (long (or (:n-estimators opts) 100))
        max-depth         (or (:max-depth opts) 5)
        min-samples-split (or (:min-samples-split opts) 2)
        min-samples-leaf  (or (:min-samples-leaf opts) 1)
        rng               (Random. (long (or (:random-state opts)
                                             (System/currentTimeMillis))))
        cat-idxs          (or (:categorical-idxs opts) #{})
        ;; computed once and reused both for training and in the returned model
        encoders          (when (seq cat-idxs)
                            (encode-categorical X cat-idxs))
        X-encoded         (if encoders (apply-encoding X encoders) X)
        X-arr             (into-array (map float-array X-encoded))
        y-arr             (float-array y)
        n-samples         (alength y-arr)
        grow-tree         (fn [_]
                            (let [idxs (int-array n-samples)]
                              ;; bootstrap: sample row ids with replacement
                              (dotimes [i n-samples]
                                (aset idxs i (.nextInt rng n-samples)))
                              (flatten-tree
                               (build-tree-node X-arr y-arr idxs n-samples
                                                0 max-depth
                                                min-samples-split
                                                min-samples-leaf
                                                rng cat-idxs))))
        trees             (mapv grow-tree (range n-estimators))]
    (RandomForest. trees [] "target" encoders)))
(defn predict
  "Predicts the target for a single sample as the mean of all tree predictions.

  Parameters:
    rf     - a fitted RandomForest
    sample - sequence of feature values; when the forest carries categorical
             encoders, the encoded columns must hold the raw category values

  Returns a double.

  The trees are stored in a persistent vector, so they are folded with reduce
  rather than array operations. Feature values are rounded through float
  before the walk: training compares float features against float thresholds,
  and without the rounding a value equal to a split point could take the other
  branch here than it did during training (and than it does in predict-batch).

  Throws ex-info when a categorical value has no entry in the encoders."
  [^RandomForest rf sample]
  (let [encoders (:categorical-encoders rf)
        encoded  (if (seq encoders)
                   (first (apply-encoding [sample] encoders))
                   sample)]
    (when (some nil? encoded)
      (throw (ex-info "predict: sample contains a missing or unseen categorical value"
                      {:sample sample})))
    (let [^doubles x (double-array (map #(double (float %)) encoded))
          trees      (:trees rf)
          total      (reduce (fn [sum t]
                               (+ (double sum) (double (tree-predict t x))))
                             0.0
                             trees)]
      (/ (double total) (count trees)))))
(defn score
  "Coefficient of determination (R^2) of the forest rf on samples X with
  targets y: 1 - SSE / SST, where SSE is the sum of squared prediction errors
  and SST the sum of squared deviations of y from its mean.

  Returns a double.
    - NaN when y is empty.
    - When every target is identical (SST = 0) the ratio is undefined; 1.0 is
      returned if the predictions are exact (SSE = 0) and 0.0 otherwise.

  Both arrays are hinted as float[] so the reductions compile to direct array
  access instead of reflective calls."
  [rf X y]
  (let [^floats preds (predict-batch rf X)
        ^floats y-arr (float-array y)
        n             (alength y-arr)
        sse           (areduce preds i s 0.0
                               (+ s (Math/pow (- (aget preds i) (aget y-arr i)) 2.0)))
        mean          (/ (areduce y-arr i s 0.0 (+ s (aget y-arr i))) n)
        sst           (areduce y-arr i s 0.0
                               (+ s (Math/pow (- (aget y-arr i) mean) 2.0)))]
    (cond
      (zero? n)   Double/NaN
      (zero? sst) (if (zero? sse) 1.0 0.0)
      :else       (- 1.0 (/ sse sst)))))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;;;; 8. REPL demo | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
(comment
  ;; Read the CSV into a sequence of string vectors. The first row is skipped
  ;; below, i.e. it is treated as a header.
  (def iris (with-open [r (io/reader "iris.csv")]
              (doall (map #(str/split % #",") (line-seq r)))))

  ;; First four columns are parsed as the numeric features. The static method
  ;; is wrapped in a fn literal so this also compiles on Clojure versions that
  ;; do not accept a bare Class/method symbol as a function value.
  (def X (mapv (fn [row] (mapv #(Float/parseFloat %) (take 4 row)))
               (rest iris)))

  ;; Last column is the species label, mapped to a numeric target.
  ;; case has no default clause: an unexpected label throws.
  (def y (mapv #(case (peek %) "setosa" 0.0 "versicolor" 1.0 "virginica" 2.0)
               (rest iris)))

  (def rf (fit X y {:n-estimators 50 :max-depth 4 :random-state 42}))
  (score rf X y) ;; ~0.97
  (predict rf [5.1 3.5 1.4 0.2]) ;; ≈ 0.0
  )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; Usage sketch for the random-forest namespace. X, y, test-X and test-y are
;; not defined in this snippet and must be supplied by the caller.
(require '[random-forest :as rf]) | |
;; Train 100 trees of depth <= 8 with a fixed seed; columns 2 and 3 are
;; declared categorical.
(def model (rf/fit X y {:n-estimators 100 | |
:max-depth 8 | |
:random-state 123 | |
:categorical-idxs #{2 3}})) | |
;; Single-sample call; the values at positions 2 and 3 are raw category labels.
(rf/predict model [1.2 5.3 "high" "yes"]) | |
;; Batch prediction over test-X.
(rf/predict-batch model test-X) | |
;; Score of the model on test-X against test-y.
(rf/score model test-X test-y) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment