Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support stat with categorical split in graphviz dump. #11053

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,8 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) const {
std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str,
std::string cond, uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand Down Expand Up @@ -345,18 +344,16 @@ class TextGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}

std::string Categorical(RegTree const &tree, int32_t nid,
uint32_t depth) const override {
std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
auto cats = GetSplitCategories(tree, nid);
std::string cats_str = PrintCatsAsSet(cats);
static std::string const kNodeTemplate =
"{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
std::string const result =
SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
std::string const result = SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
return result;
}

std::string NodeStat(RegTree const& tree, int32_t nid) const override {
std::string NodeStat(RegTree const& tree, bst_node_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
Expand Down Expand Up @@ -679,15 +676,12 @@ class GraphvizGenerator : public TreeGenerator {
std::string result;
if (this->with_stats_) {
CHECK(!tree.IsMultiTarget()) << MTNotImplemented();
result = SuperT::Match(
kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{stat}", Match("\ncover={cover}\ngain={gain}",
{{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)},
{"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}})},
{"{params}", param_.condition_node_params}});
result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{stat}", this->NodeStat(tree, nidx)},
{"{params}", param_.condition_node_params}});
} else {
result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
Expand All @@ -703,9 +697,15 @@ class GraphvizGenerator : public TreeGenerator {
return result;
};

std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
/**
 * @brief Format the statistics of a single node as a text fragment.
 *
 * The `{gain}` placeholder receives the node's `loss_chg` and the `{cover}`
 * placeholder receives its `sum_hess`; both values are converted with
 * `std::to_string` before substitution.
 *
 * @param tree Tree whose per-node statistics are read.
 * @param nidx Index of the node to describe.
 * @return The template with both placeholders filled in by `Match`.
 */
std::string NodeStat(RegTree const& tree, bst_node_t nidx) const override {
  // Each entry starts with a newline character; presumably this puts every
  // statistic on its own line of the node label -- confirm against callers.
  static std::string const kStatTemplate = "\ngain={gain}\ncover={cover}";

  auto const& node_stat = tree.Stat(nidx);
  std::string const cover_str = std::to_string(node_stat.sum_hess);
  std::string const gain_str = std::to_string(node_stat.loss_chg);

  return Match(kStatTemplate, {{"{cover}", cover_str}, {"{gain}", gain_str}});
}

std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t /*depth*/) const override {
static std::string const kLabelTemplate =
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
" {nid} [ label=\"{fname}:{cond}{stat}\" {params}]\n";
auto cats = GetSplitCategories(tree, nidx);
auto cats_str = PrintCatsAsSet(cats);
auto split_index = tree.SplitIndex(nidx);
Expand All @@ -714,6 +714,7 @@ class GraphvizGenerator : public TreeGenerator {
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{cond}", cats_str},
{"{stat}", this->NodeStat(tree, nidx)},
{"{params}", param_.condition_node_params}});

result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) {
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
ASSERT_NE(str.find("gain"), std::string::npos);

if (format == "json") {
// Make sure it's valid JSON
Expand Down
Loading