我有关于 XGBoost 的问题。
你知道如何知道XGBoost中创建的树的数量吗? 与 RandomForest 不同,RandomForest 是由模型制作者决定生成多少棵树,而 XGBoost 基本上会持续创建树,直到损失函数达到一定的数字。所以我想知道这个。
谢谢你。
这有点歪,但我目前正在做的是
dump
-ing 模型(XGBoost 生成一个列表,其中每个元素都是单个树的字符串表示),然后计算列表中有多少个元素:
# clf is a XGBoost model fitted using the sklearn API
dump_list = clf.get_booster().get_dump()
num_trees = len(dump_list)
提交答案可能为时已晚。但尽管如此,我最近做了以下事情:
加载模型并将模型保存为 json 文件。然后加载 json 文件并从 json 文件中打印 num_tree 详细信息。
您可以检查以下代码片段: 运行它(将代码片段保存到文件 xgb_tree_count.py 后): python xgb_tree_count.py 模型文件路径
import sys
import json
import xgboost as xgb
if len(sys.argv) < 2:
print(f'Usage: {sys.argv[0]} <model-file>')
exit(1)
loaded_model = xgb.Booster()
loaded_model.load_model(sys.argv[1])
loaded_model.save_model('/tmp/a_model.json')
with open('/tmp/a_model.json', 'r') as fp:
jsonrepr = json.load(fp)
print(jsonrepr['learner']['gradient_booster']['model']['gbtree_model_param']['num_trees'])
在java中,似乎没有直接的方法来做到这一点。但是,您可以使用模型转储的结果来获取实际的树木数量。使用经过训练的
Booster
:
int numberOfTrees = booster.getModelDump("", false, "text").length;