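I would like to view the hyperparameters passed to a Keras RandomForestModel (from TensorFlow Decision Forests, imported as tfdf). I expected this to be possible with model.get_config().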
Here is the method in my RandomForestWrapper class that creates the model:
```python
import tensorflow_decision_forests as tfdf

# Method of RandomForestWrapper: builds a Random Forest from a params dict
def add_new_model(self, model_name, params):
    self.train_test_split()
    model = tfdf.keras.RandomForestModel(
        random_seed=params["random_seed"],
        num_trees=params["num_trees"],
        categorical_algorithm=params["categorical_algorithm"],
        compute_oob_performances=params["compute_oob_performances"],
        growing_strategy=params["growing_strategy"],
        honest=params["honest"],
        max_depth=params["max_depth"],
        max_num_nodes=params["max_num_nodes"],
    )
    print(model.get_config())
    # Register the model under its name so it can be trained later
    self.models.update({model_name: model})
    print(f"{model_name} added")
```
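One workaround that does not rely on get_config() at all is to keep a copy of each params dict in the wrapper at the moment the model is created. The sketch below is a hypothetical, trimmed-down version of such a wrapper: model_factory, model_params and get_params are illustrative names of my own, not part of the original class or of TF-DF, and the train_test_split step is left out. It passes the params dict straight to the constructor, assuming its keys match RandomForestModel's keyword arguments (as they do in params_v2).

```python
import copy


class RandomForestWrapper:
    """Hypothetical minimal wrapper that stores each model's hyperparameters next to it."""

    def __init__(self, model_factory=None):
        # model_factory builds a model from keyword arguments; it defaults to TF-DF's
        # RandomForestModel but can be swapped out (e.g. for testing without TF-DF).
        self._model_factory = model_factory or self._default_factory
        self.models = {}
        self.model_params = {}

    @staticmethod
    def _default_factory(**kwargs):
        # Imported lazily so the class can be used without TF-DF installed.
        import tensorflow_decision_forests as tfdf
        return tfdf.keras.RandomForestModel(**kwargs)

    def add_new_model(self, model_name, params):
        # Assumes every key in params is a valid constructor keyword argument.
        model = self._model_factory(**params)
        self.models[model_name] = model
        # Deep copy so later edits to the caller's dict do not change the record.
        self.model_params[model_name] = copy.deepcopy(params)
        return model

    def get_params(self, model_name):
        # Return a copy so callers cannot mutate the stored record.
        return dict(self.model_params[model_name])
```

With this, rf_models.get_params("model_2") would return the dict that was used to build "model_2", independently of what the Keras model itself reports.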
Example parameters:
```python
params_v2 = {
    "random_seed": 123456,
    "num_trees": 1000,
    "categorical_algorithm": "CART",
    "compute_oob_performances": True,
    "growing_strategy": "LOCAL",
    "honest": True,
    "max_depth": 8,
    "max_num_nodes": None,
}
```
Then I instantiate the class and train the model:
```python
rf_models = RF(data, obs_col="obs", class_col="cell_type")
rf_models.add_new_model("model_2", params_v2)
rf_models.train_model("model_2", verbose=False, metrics=["Accuracy"])
model = rf_models.models["model_2"]
model.get_config()
# Output:
# {}
```
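The model summary shows that the parameters were accepted, so they are set on the model; I just cannot retrieve them with get_config().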
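To read the values back from the model object itself, recent TF-DF releases appear to expose the constructor hyperparameters through a learner_params property on the Keras model. I have not confirmed this for every version, so the sketch below treats it as an assumption and falls back to a dict you supply (for example, params_v2 or one stored by the wrapper). model_hyperparameters and diff_params are hypothetical helper names.

```python
from typing import Any, Mapping, Optional

_MISSING = "<missing>"


def model_hyperparameters(model: Any, fallback: Optional[Mapping[str, Any]] = None) -> dict:
    """Return the hyperparameters a model was constructed with.

    Assumes the model may expose a `learner_params` attribute (as recent TF-DF
    releases appear to); if it is absent or None, use `fallback` instead.
    """
    params = getattr(model, "learner_params", None)
    if params is None:
        params = fallback or {}
    return dict(params)


def diff_params(requested: Mapping[str, Any], actual: Mapping[str, Any]) -> dict:
    """Map each requested key whose value differs in, or is missing from, `actual`
    to a (requested, actual) pair."""
    return {
        key: (value, actual.get(key, _MISSING))
        for key, value in requested.items()
        if actual.get(key, _MISSING) != value
    }
```

Usage would be something like hp = model_hyperparameters(rf_models.models["model_2"]) followed by diff_params(params_v2, hp); an empty result suggests every requested value is what the model reports.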