Last active
February 26, 2021 04:51
-
-
Save KMarkert/15469eb59efc15ef655b2b3f51e9db01 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
def sklearn_tree_to_ee_string(estimator, feature_names):
    """Convert one fitted sklearn decision tree into an R-style text string.

    The string layout (node number, feature, sign, threshold, samples,
    criterion, value, with leaves marked by a trailing ``*``) matches the
    text-tree format consumed by Google Earth Engine's decision-tree
    classifier — presumably ``ee.Classifier.decisionTree`` — TODO confirm
    against the EE docs.

    Parameters
    ----------
    estimator : fitted sklearn tree-based estimator
        Must expose the ``tree_`` attribute (e.g. DecisionTreeClassifier,
        or one element of ``RandomForest*.estimators_``).
    feature_names : sequence of str
        Band/feature names indexed by the tree's feature indices.

    Returns
    -------
    str
        The full tree as a single newline-delimited string.

    Raises
    ------
    RuntimeError
        If ``tree_.value`` has an unexpected number of dimensions.
    """
    # extract out the information needed to build the tree string
    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right
    feature_idx = estimator.tree_.feature
    impurities = estimator.tree_.impurity
    n_samples = estimator.tree_.n_node_samples
    thresholds = estimator.tree_.threshold
    # map feature indices to names; leaf nodes have index -1, which wraps
    # to the last name — harmless because leaf feature names are unused
    features = [feature_names[i] for i in feature_idx]
    raw_vals = estimator.tree_.value
    if raw_vals.ndim == 3:
        # classifier case: take argmax along class axis from values
        values = np.squeeze(raw_vals.argmax(axis=-1))
    elif raw_vals.ndim == 2:
        # regressor case: take values and drop unneeded axis
        values = np.squeeze(raw_vals)
    else:
        raise RuntimeError("could not understand estimator type and parse out the values")
    # use iterative pre-order search to extract node depth and leaf information
    node_ids = np.zeros(shape=n_nodes, dtype=np.int64)
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1
        node_ids[node_id] = node_id
        # If we have a test node
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True
    # create a table of the initial structure
    # each row is a node or leaf
    df = pd.DataFrame(
        {
            "node_id": node_ids,
            "node_depth": node_depth,
            "is_leaf": is_leaves,
            "children_left": children_left,
            "children_right": children_right,
            "value": values,
            "criterion": impurities,
            "n_samples": n_samples,
            "threshold": thresholds,
            "feature_name": features,
            "sign": ["<="] * n_nodes,
        }
    )
    # the table representation does not have left vs right node structure
    # so we need to add in right nodes in the correct location
    # we do this by first calculating which nodes are right and then insert them at the correct index
    # get a dict of right node rows and assign key based on index where to insert
    inserts = {}
    for row in df.itertuples():
        child_r = row.children_right
        if child_r > row.Index:
            # duplicate the split row, flip the sign for the right branch
            ordered_row = np.array(row)
            ordered_row[-1] = ">"
            inserts[child_r] = ordered_row[1:]  # drop index value
    # sort the inserts as to keep track of the additive indexing
    inserts_sorted = {k: inserts[k] for k in sorted(inserts.keys())}
    # loop through the row inserts and add to table (array)
    # NOTE: each earlier insert shifts later positions, hence the +i offset
    table_values = df.values
    for i, k in enumerate(inserts_sorted.keys()):
        table_values = np.insert(table_values, (k + i), inserts_sorted[k], axis=0)
    # make the ordered table array into a dataframe
    # note: df is dtype "object", need to cast later on
    ordered_df = pd.DataFrame(table_values, columns=df.columns)
    max_depth = np.max(ordered_df.node_depth.astype(int))
    # root line: "9999 9999" are placeholder stats; total impurity in parens
    tree_str = f"1) root {n_samples[0]} 9999 9999 ({impurities.sum()})\n"
    previous_depth = -1
    cnts = []
    # loop through the nodes and calculate the node number and values per node
    for row in ordered_df.itertuples():
        # NOTE: rebinds the name node_depth (was the depth array above) to
        # this row's scalar depth; the array is not used past this point
        node_depth = int(row.node_depth)
        left = int(row.children_left)
        right = int(row.children_right)
        if left != right:
            # derive the printed node number from the previous node's
            # number and the depth transition (down: *2, same: +1,
            # up: continue from the last sibling at this depth)
            if row.Index == 0:
                cnt = 2
            elif previous_depth > node_depth:
                depths = ordered_df.node_depth.values[: row.Index]
                idx = np.where(depths == node_depth)[0][-1]
                cnt = cnts[idx] + 1
            elif previous_depth < node_depth:
                cnt = cnts[row.Index - 1] * 2
            elif previous_depth == node_depth:
                cnt = cnts[row.Index - 1] + 1
            if node_depth == (max_depth - 1):
                # one level above max depth: the next table row is the leaf
                value = float(ordered_df.iloc[row.Index + 1].value)
                samps = int(ordered_df.iloc[row.Index + 1].n_samples)
                criterion = float(ordered_df.iloc[row.Index + 1].criterion)
                tail = " *\n"  # " *" marks a terminal (leaf) line
            else:
                # if this split's child on the matching side is a leaf that
                # appears later in the table, print the leaf's stats here
                if (
                    (bool(ordered_df.loc[ordered_df.node_id == left].iloc[0].is_leaf))
                    and (
                        bool(
                            int(row.Index)
                            < int(ordered_df.loc[ordered_df.node_id == left].index[0])
                        )
                    )
                    and (str(row.sign) == "<=")
                ):
                    rowx = ordered_df.loc[ordered_df.node_id == left].iloc[0]
                    tail = " *\n"
                    value = float(rowx.value)
                    samps = int(rowx.n_samples)
                    criterion = float(rowx.criterion)
                elif (
                    (bool(ordered_df.loc[ordered_df.node_id == right].iloc[0].is_leaf))
                    and (
                        bool(
                            int(row.Index)
                            < int(ordered_df.loc[ordered_df.node_id == right].index[0])
                        )
                    )
                    and (str(row.sign) == ">")
                ):
                    rowx = ordered_df.loc[ordered_df.node_id == right].iloc[0]
                    tail = " *\n"
                    value = float(rowx.value)
                    samps = int(rowx.n_samples)
                    criterion = float(rowx.criterion)
                else:
                    # interior split node: print its own stats
                    value = float(row.value)
                    samps = int(row.n_samples)
                    criterion = float(row.criterion)
                    tail = "\n"
            # extract out the information needed in each line
            spacing = (node_depth + 1) * " "  # for pretty printing
            fname = str(row.feature_name)  # name of the feature (i.e. band name)
            tresh = float(row.threshold)  # threshold
            sign = str(row.sign)
            tree_str += f"{spacing}{cnt}) {fname} {sign} {tresh:.6f} {samps} {criterion:.4f} {value:.6f}{tail}"
            previous_depth = node_depth
            cnts.append(cnt)
    return tree_str
if __name__ == "__main__":
    # do your model training here
    # model is an sklearn RandomForestClassifier or RandomForestRegressor
    estimators = model.estimators_

    # convert each tree in the forest to its Earth Engine string form
    trees = []
    for estimator in estimators:
        string = sklearn_tree_to_ee_string(estimator, features)
        # BUG FIX: original did trees.append(trees), appending the list to
        # itself instead of collecting the tree string
        trees.append(string)

    # save tree strings to text files or GCS
    # or use directly with ee
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment