Created
February 20, 2019 13:21
-
-
Save rust-play/a57cdaaa637540caaee506854adc5606 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#[macro_use] | |
extern crate error_chain; // 0.12.0 | |
extern crate num_traits; // 0.2.6 | |
use num_traits::AsPrimitive; | |
use std::fmt::Debug; | |
error_chain! { | |
types { TractError, TractErrorKind, TractResultExt, TractResult; } | |
foreign_links {} | |
errors { TFString {} } | |
} | |
/// A binary comparison operator over `f32` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cmp {
    /// `x <= y`
    LessEqual,
    /// `x < y`
    Less,
    /// `x >= y`
    GreaterEqual,
    /// `x > y`
    Greater,
    /// `x == y`
    Equal,
    /// `x != y`
    NotEqual,
}

impl Cmp {
    /// Evaluates `x <op> y` for this operator.
    ///
    /// The plain IEEE-754 operators are used. Any comparison involving NaN is
    /// therefore `false`, except [`Cmp::NotEqual`], which is `true`.
    pub fn compare(&self, x: f32, y: f32) -> bool {
        use self::Cmp::*;
        match self {
            Less => x < y,
            LessEqual => x <= y,
            Greater => x > y,
            GreaterEqual => x >= y,
            Equal => x == y,
            NotEqual => x != y,
        }
    }
}
#[derive(Copy, Clone, Debug)] | |
pub struct Branch { | |
pub cmp: Cmp, | |
pub feature_id: usize, | |
pub value: f32, | |
pub true_id: usize, | |
pub false_id: usize, | |
pub nan_is_true: bool, | |
} | |
impl Branch { | |
pub fn child_id(&self, feature: f32) -> usize { | |
let condition = if feature.is_nan() { | |
self.nan_is_true | |
} else { | |
self.cmp.compare(feature, self.value) | |
}; | |
if condition { | |
self.true_id | |
} else { | |
self.false_id | |
} | |
} | |
} | |
#[derive(Copy, Clone, Debug)] | |
pub enum Node<L> { | |
Branch(Branch), | |
Leaf(L), | |
} | |
#[derive(Clone, Debug)] | |
pub struct Tree<L> { | |
nodes: Vec<Node<L>>, | |
root_id: usize, | |
} | |
impl<L: Debug + Clone> Tree<L> { | |
pub fn eval_unchecked<X, T>(&self, x: X) -> TractResult<&L> | |
where | |
X: AsRef<[T]>, // not entirely correct (e.g. ndarray, strides etc) | |
T: AsPrimitive<f32>, | |
{ | |
let x = x.as_ref(); | |
let mut node_id = self.root_id; | |
loop { | |
let node = unsafe { self.nodes.get_unchecked(node_id) }; | |
match node { | |
Node::Branch(ref b) => { | |
let feature = unsafe { *x.get_unchecked(b.feature_id) }; | |
node_id = b.child_id(feature.as_()); | |
} | |
Node::Leaf(ref leaf) => { | |
return Ok(&leaf); | |
} | |
} | |
} | |
} | |
fn branches(&self) -> impl Iterator<Item = &Branch> { | |
self.nodes.iter().filter_map(|node| match node { | |
Node::Branch(ref branch) => Some(branch), | |
_ => None, | |
}) | |
} | |
fn leaves(&self) -> impl Iterator<Item = &L> { | |
self.nodes.iter().filter_map(|node| match node { | |
Node::Leaf(ref leaf) => Some(leaf), | |
_ => None, | |
}) | |
} | |
pub fn max_feature_id(&self) -> usize { | |
self.branches().map(|b| b.feature_id).max().unwrap_or(0) | |
} | |
pub fn from_nodes(nodes: &[Node<L>]) -> TractResult<Self> { | |
let len = nodes.len(); | |
ensure!(len > 0, "Invalid tree: expected non-zero node count"); | |
let mut max_feature_id = 0; | |
for node in nodes { | |
if let &Node::Branch(b) = node { | |
ensure!( | |
b.feature_id < len, | |
"Invalid node: {:?} (expected feature_id = {} < len = {})", | |
node, b.feature_id, len | |
); | |
max_feature_id = max_feature_id.max(b.feature_id); | |
ensure!( | |
b.true_id < len, | |
"Invalid node: {:?} (expected true_id = {} < len = {})", | |
node, b.true_id, len | |
); | |
ensure!( | |
b.false_id < len, | |
"Invalid node: {:?} (expected false_id = {} < len = {})", | |
node, b.false_id, len | |
); | |
} | |
} | |
Ok(Self { nodes: nodes.into(), root_id: 123 }) | |
} | |
} | |
pub struct TreeEnsemble<L> { | |
trees: Vec<Tree<L>>, | |
max_feature_id: usize, | |
} | |
impl<L: Debug + Clone> TreeEnsemble<L> { | |
pub fn from_trees(trees: &[Tree<L>]) -> Self { | |
let max_feature_id = trees.iter() | |
.map(Tree::max_feature_id).max().unwrap_or(0); | |
Self { trees: trees.into(), max_feature_id } | |
} | |
} | |
/// Program entry point; intentionally empty.
fn main() {
    // Nothing is executed here.
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment