Created
August 9, 2020 13:06
-
-
Save galbraun/b58b2937f130452ae81ea8d7d401bc0d to your computer and use it in GitHub Desktop.
Functions that extract, for every tree and every leaf of an XGBoost forest, the path of nodes from the root to that leaf (the root, the intermediate split nodes, and the leaf itself).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _root_to_leaf_route(df, stack, routes): | |
current_node = df.loc[df.Node == stack[-1]] | |
if current_node.Feature.values[0] == 'Leaf': | |
routes[current_node.Node.values[0]] = list(stack) | |
stack.pop() | |
return | |
stack.append(int(current_node.Yes.values[0].split('-')[1])) | |
_root_to_leaf_route(df, stack, routes) | |
stack.append(int(current_node.No.values[0].split('-')[1])) | |
_root_to_leaf_route(df, stack, routes) | |
stack.pop() | |
return | |
def extract_root_to_leaf_routes_for_forest(xgb):
    """Map every tree index of a fitted XGBoost model to its leaf routes.

    Parameters
    ----------
    xgb :
        A fitted XGBoost model exposing ``get_booster()`` (presumably one of
        the scikit-learn wrapper estimators; confirm against callers).

    Returns
    -------
    dict
        ``{tree_index: {leaf_node_id: [root_id, ..., leaf_id]}}``, with one
        entry per tree in ascending tree order. Each route starts at node 0
        and ends with the leaf itself.
    """
    # Build the forest dataframe once. The per-tree helper rebuilds it on
    # every call, which made a per-tree loop quadratic in the tree count.
    forest_df = xgb.get_booster().trees_to_dataframe()
    routes_forest = {}
    for tree_index, tree_df in forest_df.groupby('Tree', sort=True):
        routes = {}
        _root_to_leaf_route(tree_df, [0], routes)
        # Store plain-int keys, matching the original range()-based indices.
        routes_forest[int(tree_index)] = routes
    return routes_forest
def extract_root_to_leaf_routes_for_tree(xgb, tree_index):
    """Return the root-to-leaf route of every leaf in one tree of ``xgb``.

    Parameters
    ----------
    xgb :
        A fitted XGBoost model exposing ``get_booster()``.
    tree_index : int
        Value of the ``Tree`` column selecting the tree to walk.

    Returns
    -------
    dict
        ``{leaf_node_id: [0, ..., leaf_node_id]}``: for each leaf, the node
        ids visited from the root (node 0) down to and including the leaf.
    """
    forest_df = xgb.get_booster().trees_to_dataframe()
    tree_df = forest_df[forest_df['Tree'] == tree_index].set_index('ID')
    leaf_routes = {}
    # The walk starts at the root, node 0, with an otherwise empty path.
    _root_to_leaf_route(tree_df, [0], leaf_routes)
    return leaf_routes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment