我希望能够查看传递给keras的RandomForestModel的超参数.我认为这应该是可能的与model.get_config().

这是在我的RandomForestWrapper类中创建模型的函数:

def add_new_model(self, model_name, params):

    self.train_test_split()

        model = tfdf.keras.RandomForestModel(
            random_seed=params["random_seed"],
            num_trees=params["num_trees"],
            categorical_algorithm=params["categorical_algorithm"],
            compute_oob_performances=params["compute_oob_performances"],
            growing_strategy=params["growing_strategy"],
            honest=params["honest"],
            max_depth=params["max_depth"],
            max_num_nodes=params["max_num_nodes"]
           )

    print(model.get_config())
    self.models.update({model_name: model})
    print(f"{model_name} added")

示例参数:

params_v2 = {
    "random_seed": 123456,
    "num_trees": 1000,
    "categorical_algorithm": "CART",
    "compute_oob_performances": True,
    "growing_strategy": "LOCAL",
    "honest": True,
    "max_depth": 8,
    "max_num_nodes": None
}

然后实例化类并训练模型:

rf_models = RF(data, obs_col="obs", class_col="cell_type")
rf_models.add_new_model("model_2", params_v2)
rf_models.train_model("model_2", verbose=False, metrics=["Accuracy"])

model = rf_models.models["model_2"]
model.get_config()

##
{}

在模型总结中,我可以看到参数是可以接受的.

推荐答案

关于get_config(),请注意docs的状态:

返回模型的配置.

Config是一个Python字典(可序列化),其中包含

注意,get\u config()不保证返回

建议子类模型的开发人员重写此方法,

我想你能做的就是打电话给model.learner_params,了解你想要的细节:

import tensorflow_decision_forests as tfdf
import pprint

params_v2 = {
    "random_seed": 123456,
    "num_trees": 1000,
    "categorical_algorithm": "CART",
    "compute_oob_performances": True,
    "growing_strategy": "LOCAL",
    "honest": True,
    "max_depth": 8,
    "max_num_nodes": None
}

model = tfdf.keras.RandomForestModel().from_config(params_v2)
pprint.pprint(model.learner_params)
{'adapt_bootstrap_size_ratio_for_maximum_training_duration': False,
 'allow_na_conditions': False,
 'bootstrap_size_ratio': 1.0,
 'bootstrap_training_dataset': True,
 'categorical_algorithm': 'CART',
 'categorical_set_split_greedy_sampling': 0.1,
 'categorical_set_split_max_num_items': -1,
 'categorical_set_split_min_item_frequency': 1,
 'compute_oob_performances': True,
 'compute_oob_variable_importances': False,
 'growing_strategy': 'LOCAL',
 'honest': True,
 'honest_fixed_separation': False,
 'honest_ratio_leaf_examples': 0.5,
 'in_split_min_examples_check': True,
 'keep_non_leaf_label_distribution': True,
 'max_depth': 8,
 'max_num_nodes': None,
 'maximum_model_size_in_memory_in_bytes': -1.0,
 'maximum_training_duration_seconds': -1.0,
 'min_examples': 5,
 'missing_value_policy': 'GLOBAL_IMPUTATION',
 'num_candidate_attributes': 0,
 'num_candidate_attributes_ratio': -1.0,
 'num_oob_variable_importances_permutations': 1,
 'num_trees': 1000,
 'pure_serving_model': False,
 'random_seed': 123456,
 'sampling_with_replacement': True,
 'sorting_strategy': 'PRESORT',
 'sparse_oblique_normalization': None,
 'sparse_oblique_num_projections_exponent': None,
 'sparse_oblique_projection_density_factor': None,
 'sparse_oblique_weights': None,
 'split_axis': 'AXIS_ALIGNED',
 'uplift_min_examples_in_treatment': 5,
 'uplift_split_score': 'KULLBACK_LEIBLER',
 'winner_take_all': True}

Python相关问答推荐

如何自动抓取以下CSV

Polars比较了两个预设-有没有方法在第一次不匹配时立即失败

Python 约束无法解决n皇后之谜

无法通过python-jira访问jira工作日志(log)中的 comments

Mistral模型为不同的输入文本生成相同的嵌入

无法定位元素错误404

driver. find_element无法通过class_name找到元素'""

计算分布的标准差

isinstance()在使用dill.dump和dill.load后,对列表中包含的对象失败

不允许 Select 北极滚动?

如何删除重复的文字翻拍?

在我融化极点数据帧之后,我如何在不添加索引的情况下将其旋转回其原始形式?

合并相似列表

设置索引值每隔17行左右更改的索引

时间戳上的SOAP头签名无效

函数()参数';代码';必须是代码而不是字符串

如何让PYTHON上的Selify连接到现有的Firefox实例-我无法连接到Marionette端口

组颠倒大Pandas 数据帧

正则表达式反向查找

极地数据帧:ROLING_SUM向前看