我用Python编写了一个决策树类,它使用Node类作为树节点,如下所示:
class Node:
'''
Helper class which implements a single tree node.
'''
def __init__(self, feature=None, threshold=None, data_left=None, data_right=None, gain=None, value=None):
self.feature = feature
self.threshold = threshold
self.data_left = data_left
self.data_right = data_right
self.gain = gain
self.value = value
现在,我想为这棵树编写一个打印方法,打印这棵树每个节点的增益、特征名称和阈值,并打印叶节点的值,该值代表输入样本的最终标签。
我可以使用哪些库以图像格式打印此输出训练树?我该如何使用它们?
如果我不想以图像格式打印这棵树,我可以使用哪种算法以可读的方式打印我之前提到的信息?
我已经写了我的打印方法如下!但它根本不可读:
def print_tree(self,node,depth=0):
if node is None:
return
prefix = " " * depth
# If the node is a leaf node, print its value
if node.value is not None:
print(f"{prefix}Value: {node.value}")
else:
# Print the feature and threshold for the split at this node
print(f"{prefix}Feature: {node.feature}, Threshold: {node.threshold}")
# Recursively print the left and right subtrees
print(f"{prefix}--> Left:")
self.print_tree(node.data_left, depth + 1)
print(f"{prefix}--> Right:")
self.print_tree(node.data_right, depth + 1)
我使用 Graphviz 及其 DOT 语言来绘制图表 https://graphviz.org/doc/info/lang.html .
您可以使用命令将点文件转换为 pdf
dot -Tpdf dot_file.dot -o out_name.pdf