-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVisitor.py
49 lines (37 loc) · 1.24 KB
/
Visitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from abc import ABC, abstractmethod
import Node
class Visitor(ABC):
    """Abstract visitor over the nodes of a tree.

    Follows the Visitor design pattern described at
    https://refactoring.guru/design-patterns/visitor: a concrete visitor
    provides one handler for parent (internal) nodes and one for leaves.
    Both methods are abstract, so a subclass must override both before it
    can be instantiated.
    """

    # NOTE(review): `Node` is bound by `import Node` at the top of the file,
    # so these annotations refer to the module object, not a class — confirm
    # whether something like `Node.Node` was intended.

    @abstractmethod
    def visit_parent(self, parent: Node):
        """Handle a parent (internal) node."""

    @abstractmethod
    def visit_leaf(self, leaf: Node):
        """Handle a leaf node."""
class FeatureImportance(Visitor):
    """Visitor that counts how often each feature index appears on parent nodes.

    After a traversal, ``occurrences`` maps every ``feature_index`` seen on a
    parent node to the number of parent nodes carrying it. Leaf nodes are not
    counted. This is a plain frequency count; no weighting is applied.
    """

    def __init__(self):
        # feature_index -> number of parent nodes that carry it.
        self.occurrences = {}

    def visit_parent(self, parent: Node) -> None:
        """Count ``parent.feature_index``, then visit the left and right children.

        Children are visited left first, then right.
        """
        k = parent.feature_index
        # Single lookup with a default instead of a membership test + branch.
        self.occurrences[k] = self.occurrences.get(k, 0) + 1
        # accept_visitor is defined on the node type (not visible here); it
        # presumably dispatches back to visit_parent / visit_leaf — confirm.
        parent.left_child.accept_visitor(self)
        parent.right_child.accept_visitor(self)

    def visit_leaf(self, leaf: Node) -> None:
        """Do nothing: leaves are not counted."""
class PrinterTree(Visitor):
    """Visitor that prints the tree to stdout, one node per line.

    Nodes are printed in pre-order (a parent, then its left subtree, then its
    right subtree), each line prefixed with one tab per level of depth.
    Parent lines read ``parent, <feature_index>, <value>``; leaf lines read
    ``leaf, <value_or_label_or_depth>``.
    """

    def __init__(self):
        # Current nesting level = number of tabs prefixed to each printed line.
        self.depth = 0

    def visit_parent(self, parent: Node) -> None:
        """Print this parent node, then print both subtrees one level deeper."""
        indent = "\t" * self.depth
        print(f"{indent}parent, {parent.feature_index}, {parent.value}")
        self.depth += 1
        try:
            parent.left_child.accept_visitor(self)
            parent.right_child.accept_visitor(self)
        finally:
            # Restore the level even if a child raises, so a reused printer
            # does not keep printing with a stale, too-deep indentation.
            self.depth -= 1

    def visit_leaf(self, leaf: Node) -> None:
        """Print this leaf node at the current depth."""
        indent = "\t" * self.depth
        print(f"{indent}leaf, {leaf.value_or_label_or_depth}")