decision_tree.cc revision bcaf3cf82e23b4318f9d87f17cb5a4536febe564
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/core/platform/macros.h"
17
18namespace tensorflow {
19namespace boosted_trees {
20namespace trees {
21
22constexpr int kInvalidLeaf = -1;
23int DecisionTree::Traverse(const DecisionTreeConfig& config,
24                           const int32 sub_root_id,
25                           const utils::Example& example) {
26  if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
27    return kInvalidLeaf;
28  }
29
30  // Traverse tree starting at the provided sub-root.
31  int32 node_id = sub_root_id;
32  while (true) {
33    const auto& current_node = config.nodes(node_id);
34    switch (current_node.node_case()) {
35      case TreeNode::kLeaf: {
36        return node_id;
37      }
38      case TreeNode::kDenseFloatBinarySplit: {
39        const auto& split = current_node.dense_float_binary_split();
40        node_id = example.dense_float_features[split.feature_column()] <=
41                          split.threshold()
42                      ? split.left_id()
43                      : split.right_id();
44        break;
45      }
46      case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
47        const auto& split =
48            current_node.sparse_float_binary_split_default_left().split();
49        auto sparse_feature =
50            example.sparse_float_features[split.feature_column()];
51        node_id = !sparse_feature.has_value() ||
52                          sparse_feature.get_value() <= split.threshold()
53                      ? split.left_id()
54                      : split.right_id();
55        break;
56      }
57      case TreeNode::kSparseFloatBinarySplitDefaultRight: {
58        const auto& split =
59            current_node.sparse_float_binary_split_default_right().split();
60        auto sparse_feature =
61            example.sparse_float_features[split.feature_column()];
62        node_id = sparse_feature.has_value() &&
63                          sparse_feature.get_value() <= split.threshold()
64                      ? split.left_id()
65                      : split.right_id();
66        break;
67      }
68      case TreeNode::kCategoricalIdBinarySplit: {
69        const auto& split = current_node.categorical_id_binary_split();
70        const auto& features =
71            example.sparse_int_features[split.feature_column()];
72        node_id = features.find(split.feature_id()) != features.end()
73                      ? split.left_id()
74                      : split.right_id();
75        break;
76      }
77      case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
78        const auto& split =
79            current_node.categorical_id_set_membership_binary_split();
80        // The new node_id = left_id if a feature is found, or right_id.
81        node_id = split.right_id();
82        for (const int64 feature_id :
83             example.sparse_int_features[split.feature_column()]) {
84          if (std::binary_search(split.feature_ids().begin(),
85                                 split.feature_ids().end(), feature_id)) {
86            node_id = split.left_id();
87            break;
88          }
89        }
90        break;
91      }
92      case TreeNode::NODE_NOT_SET: {
93        QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
94        break;
95      }
96    }
97    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
98                          << current_node.DebugString();
99  }
100}
101
102void DecisionTree::LinkChildren(const std::vector<int32>& children,
103                                TreeNode* parent_node) {
104  // Decide how to link children depending on the parent node's type.
105  auto children_it = children.begin();
106  switch (parent_node->node_case()) {
107    case TreeNode::kLeaf: {
108      // Essentially no-op.
109      QCHECK(children.empty()) << "A leaf node cannot have children.";
110      break;
111    }
112    case TreeNode::kDenseFloatBinarySplit: {
113      QCHECK(children.size() == 2)
114          << "A binary split node must have exactly two children.";
115      auto* split = parent_node->mutable_dense_float_binary_split();
116      split->set_left_id(*children_it);
117      split->set_right_id(*++children_it);
118      break;
119    }
120    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
121      QCHECK(children.size() == 2)
122          << "A binary split node must have exactly two children.";
123      auto* split =
124          parent_node->mutable_sparse_float_binary_split_default_left()
125              ->mutable_split();
126      split->set_left_id(*children_it);
127      split->set_right_id(*++children_it);
128      break;
129    }
130    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
131      QCHECK(children.size() == 2)
132          << "A binary split node must have exactly two children.";
133      auto* split =
134          parent_node->mutable_sparse_float_binary_split_default_right()
135              ->mutable_split();
136      split->set_left_id(*children_it);
137      split->set_right_id(*++children_it);
138      break;
139    }
140    case TreeNode::kCategoricalIdBinarySplit: {
141      QCHECK(children.size() == 2)
142          << "A binary split node must have exactly two children.";
143      auto* split = parent_node->mutable_categorical_id_binary_split();
144      split->set_left_id(*children_it);
145      split->set_right_id(*++children_it);
146      break;
147    }
148    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
149      QCHECK(children.size() == 2)
150          << "A binary split node must have exactly two children.";
151      auto* split =
152          parent_node->mutable_categorical_id_set_membership_binary_split();
153      split->set_left_id(*children_it);
154      split->set_right_id(*++children_it);
155      break;
156    }
157    case TreeNode::NODE_NOT_SET: {
158      QCHECK(false) << "A non-set node cannot have children.";
159      break;
160    }
161  }
162}
163
164std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
165  // A node's children depend on its type.
166  switch (node.node_case()) {
167    case TreeNode::kLeaf: {
168      return {};
169    }
170    case TreeNode::kDenseFloatBinarySplit: {
171      const auto& split = node.dense_float_binary_split();
172      return {split.left_id(), split.right_id()};
173    }
174    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
175      const auto& split = node.sparse_float_binary_split_default_left().split();
176      return {split.left_id(), split.right_id()};
177    }
178    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
179      const auto& split =
180          node.sparse_float_binary_split_default_right().split();
181      return {split.left_id(), split.right_id()};
182    }
183    case TreeNode::kCategoricalIdBinarySplit: {
184      const auto& split = node.categorical_id_binary_split();
185      return {split.left_id(), split.right_id()};
186    }
187    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
188      const auto& split = node.categorical_id_set_membership_binary_split();
189      return {split.left_id(), split.right_id()};
190    }
191    case TreeNode::NODE_NOT_SET: {
192      return {};
193    }
194  }
195}
196
197}  // namespace trees
198}  // namespace boosted_trees
199}  // namespace tensorflow
200