Last active
August 17, 2021 11:20
-
-
Save zafercavdar/c8be4ce3758fdca4fb60a1672fb4f451 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from copy import deepcopy

from sklearn.tree import export_graphviz

# Dump the fitted tree to Graphviz DOT. `best_model`, `feature_pd` and `le` come
# from earlier notebook cells (not shown); presumably a fitted
# DecisionTreeClassifier, the training-feature DataFrame and the target's
# LabelEncoder -- confirm against the notebook.
export_graphviz(
    best_model,
    out_file="tree.dot",
    feature_names=feature_pd.columns,
    class_names=le.classes_,
    rounded=True,
    proportion=False,
    precision=2,
    filled=True,
)

# export_graphviz writes the file as UTF-8; read it back the same way so
# non-ASCII feature/class names survive on platforms with another default.
with open("tree.dot", "r", encoding="utf-8") as f:
    data = f.read()

# Predicted class of a node: the text after the literal "\nclass = " (a
# backslash-n inside the DOT label) up to the label's closing quote. Any
# character except a quote is accepted, so names like "non-AOM" also match.
pattern = re.compile(r"\\nclass = ([^\"]+)\"")

# Split condition on the first line of a label, i.e. `"<condition>\n<impurity>`.
# Group 2 stops at the first backslash, so only the first label line is taken.
# Substituting "\1\3" drops the condition and keeps the impurity line. All
# classifier impurity names are accepted, not only "entropy".
pattern2 = re.compile(r"(\")([^\"\\]+\\n)(gini|entropy|log_loss)")
class Node:
    """One node of a decision tree parsed from a Graphviz DOT export.

    A node is created from its DOT definition line; edges are attached
    afterwards with ``add_child``.
    """

    def __init__(self, _id: str, _class: str, def_str: str):
        # DOT node identifier (first token of the definition line).
        self._id = _id
        # Child nodes, in the order their edges were added.
        self.children = []
        # Class name taken from the node's "class = ..." label line.
        self._class = _class
        # Full DOT definition line; rewritten when the node is collapsed.
        self.def_str = def_str

    def __repr__(self):
        return (
            f"{type(self).__name__}({self._id!r}, {self._class!r}, "
            f"children={len(self.children)})"
        )

    def add_child(self, subnode):
        """Append *subnode* to this node's children."""
        self.children.append(subnode)

    @property
    def is_leaf(self) -> bool:
        """True when the node has no children."""
        return not self.children

    def prune_if_needed(self):
        """Collapse this node if every child is a leaf predicting the same class.

        Such a split does not change the prediction. The children are
        detached and the split condition is stripped from ``def_str`` with
        the module-level ``pattern2`` regex.

        Returns:
            The detached child nodes, or an empty list if nothing was pruned.
        """
        if not self.children:
            return []
        single_class = len({child._class for child in self.children}) == 1
        if not (single_class and all(child.is_leaf for child in self.children)):
            return []
        print(f"Pruning {self._id}")
        # Hand back the detached children themselves: the caller in this file
        # only reads their `_id`, so deep-copying them is wasted work.
        removed, self.children = self.children, []
        self.def_str = pattern2.sub("\\1\\3", self.def_str)
        return removed
# DOT statements starting with these prefixes are neither node definitions nor
# edges and are copied through untouched. "rankdir" and "{" presumably cover
# export_graphviz's rotate=True / leaves_parallel=True output (TODO confirm);
# no node-definition line can start with them.
_DOT_PASSTHROUGH_PREFIXES = ("digraph", "node", "edge", "}", "rankdir", "{")


def _is_passthrough_line(line):
    """Return True for blank lines and DOT header/attribute statements."""
    return not line.strip() or line.startswith(_DOT_PASSTHROUGH_PREFIXES)


def _edge_endpoints(line):
    """Split an edge line ``"<from> -> <to> [attrs] ;"`` into ``(from, to)`` ids."""
    source, target = line.split(" -> ")
    return source, target.split(" ")[0]


def _parse_nodes(lines):
    """Build the Node tree from DOT lines and return ``{node_id: Node}``."""
    nodes = {}
    for line in lines:
        if " -> " in line:
            source, target = _edge_endpoints(line)
            nodes[source].add_child(nodes[target])
        elif _is_passthrough_line(line):
            continue
        else:
            parts = line.split()
            match = pattern.search(" ".join(parts[1:]))
            if match is None:
                raise ValueError(f"no 'class = ...' label found in DOT line: {line!r}")
            nodes[parts[0]] = Node(parts[0], match.group(1), line)
    return nodes


def _collapse(nodes):
    """Prune nodes repeatedly until a full pass changes nothing.

    Returns the set of ids of every removed node.
    """
    deleted_ids = set()
    changed = True
    while changed:
        changed = False
        for node in nodes.values():
            removed = node.prune_if_needed()
            if removed:
                changed = True
                deleted_ids.update(child._id for child in removed)
    return deleted_ids


def _render(lines, nodes, deleted_ids):
    """Yield the DOT lines to keep, using each surviving node's current definition."""
    for line in lines:
        if " -> " in line:
            if _edge_endpoints(line)[1] not in deleted_ids:
                yield line
        elif _is_passthrough_line(line):
            yield line
        else:
            node_id = line.split()[0]
            if node_id not in deleted_ids:
                # def_str is the whole definition line, node id included.
                yield nodes[node_id].def_str


def prune(graph_data):
    """Remove redundant splits from a decision-tree DOT export.

    A split node whose children are all leaves predicting the same class is
    collapsed into a leaf (see ``Node.prune_if_needed``). Collapsing repeats
    until nothing changes, so chains of redundant splits of any depth are
    removed. A fixed number of passes is not enough, because one pass can
    miss a parent that only becomes collapsible after its children are
    collapsed later in that pass.

    Args:
        graph_data: DOT source, one statement per line.

    Returns:
        The DOT source without the removed nodes and the edges leading to
        them. Collapsed nodes carry their rewritten definition. Header lines
        and blank lines are kept unchanged.

    Raises:
        ValueError: if a node-definition line has no "class = ..." label.
    """
    lines = graph_data.split("\n")
    nodes = _parse_nodes(lines)
    deleted_ids = _collapse(nodes)
    return "\n".join(_render(lines, nodes, deleted_ids))
import re
import subprocess

from IPython.display import Image, display

# Optional, project-specific relabelling. Thresholds of exactly 0.5 are assumed
# to come from binary AOM-association flags, so "x <= 0.5" is shown as
# "x is AOM-associated". That phrase is presumably true on the split's *False*
# branch, so the root's True/False edge labels are swapped to match.
# NOTE(review): any other feature split at exactly 0.5 is relabelled too.
updated_data = data.replace("<= 0.5", "is AOM-associated")
# export_graphviz writes the root's edge labels as headlabel="True"/"False".
# Swap only those; a blanket str.replace would also flip "True"/"False"
# appearing inside feature or class names.
updated_data = re.sub(
    r'headlabel="(True|False)"',
    lambda m: 'headlabel="%s"' % ("False" if m.group(1) == "True" else "True"),
    updated_data,
)

# Prune before opening the output, so a failure doesn't leave an empty file.
pruned = prune(updated_data)
with open("tree2.dot", "w", encoding="utf-8") as f:
    f.write(pruned)

# Render with Graphviz's `dot` (must be on PATH). check=True surfaces
# rendering failures, which the notebook `!dot` shell escape ignored.
subprocess.run(["dot", "-Tpng", "tree.dot", "-o", "tree.png", "-Gdpi=300"], check=True)
display(Image(filename="tree.png"))
# Render the pruned tree to its own file instead of overwriting tree.png.
subprocess.run(["dot", "-Tpng", "tree2.dot", "-o", "tree2.png", "-Gdpi=300"], check=True)
display(Image(filename="tree2.png"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment