decision_tree.cc revision 3f92cad88767b9e0d4febe4e02ad3c31d02d5daa
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/core/platform/macros.h"
17
18#include <algorithm>
19
20namespace tensorflow {
21namespace boosted_trees {
22namespace trees {
23
24constexpr int kInvalidLeaf = -1;
25int DecisionTree::Traverse(const DecisionTreeConfig& config,
26                           const int32 sub_root_id,
27                           const utils::Example& example) {
28  if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
29    return kInvalidLeaf;
30  }
31
32  // Traverse tree starting at the provided sub-root.
33  int32 node_id = sub_root_id;
34  while (true) {
35    const auto& current_node = config.nodes(node_id);
36    switch (current_node.node_case()) {
37      case TreeNode::kLeaf: {
38        return node_id;
39      }
40      case TreeNode::kDenseFloatBinarySplit: {
41        const auto& split = current_node.dense_float_binary_split();
42        node_id = example.dense_float_features[split.feature_column()] <=
43                          split.threshold()
44                      ? split.left_id()
45                      : split.right_id();
46        break;
47      }
48      case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
49        const auto& split =
50            current_node.sparse_float_binary_split_default_left().split();
51        auto sparse_feature =
52            example.sparse_float_features[split.feature_column()];
53        node_id = !sparse_feature.has_value() ||
54                          sparse_feature.get_value() <= split.threshold()
55                      ? split.left_id()
56                      : split.right_id();
57        break;
58      }
59      case TreeNode::kSparseFloatBinarySplitDefaultRight: {
60        const auto& split =
61            current_node.sparse_float_binary_split_default_right().split();
62        auto sparse_feature =
63            example.sparse_float_features[split.feature_column()];
64        node_id = sparse_feature.has_value() &&
65                          sparse_feature.get_value() <= split.threshold()
66                      ? split.left_id()
67                      : split.right_id();
68        break;
69      }
70      case TreeNode::kCategoricalIdBinarySplit: {
71        const auto& split = current_node.categorical_id_binary_split();
72        const auto& features =
73            example.sparse_int_features[split.feature_column()];
74        node_id = features.find(split.feature_id()) != features.end()
75                      ? split.left_id()
76                      : split.right_id();
77        break;
78      }
79      case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
80        const auto& split =
81            current_node.categorical_id_set_membership_binary_split();
82        // The new node_id = left_id if a feature is found, or right_id.
83        node_id = split.right_id();
84        for (const int64 feature_id :
85             example.sparse_int_features[split.feature_column()]) {
86          if (std::binary_search(split.feature_ids().begin(),
87                                 split.feature_ids().end(), feature_id)) {
88            node_id = split.left_id();
89            break;
90          }
91        }
92        break;
93      }
94      case TreeNode::NODE_NOT_SET: {
95        LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
96        break;
97      }
98    }
99    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
100                          << current_node.DebugString();
101  }
102}
103
104void DecisionTree::LinkChildren(const std::vector<int32>& children,
105                                TreeNode* parent_node) {
106  // Decide how to link children depending on the parent node's type.
107  auto children_it = children.begin();
108  switch (parent_node->node_case()) {
109    case TreeNode::kLeaf: {
110      // Essentially no-op.
111      QCHECK(children.empty()) << "A leaf node cannot have children.";
112      break;
113    }
114    case TreeNode::kDenseFloatBinarySplit: {
115      QCHECK(children.size() == 2)
116          << "A binary split node must have exactly two children.";
117      auto* split = parent_node->mutable_dense_float_binary_split();
118      split->set_left_id(*children_it);
119      split->set_right_id(*++children_it);
120      break;
121    }
122    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
123      QCHECK(children.size() == 2)
124          << "A binary split node must have exactly two children.";
125      auto* split =
126          parent_node->mutable_sparse_float_binary_split_default_left()
127              ->mutable_split();
128      split->set_left_id(*children_it);
129      split->set_right_id(*++children_it);
130      break;
131    }
132    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
133      QCHECK(children.size() == 2)
134          << "A binary split node must have exactly two children.";
135      auto* split =
136          parent_node->mutable_sparse_float_binary_split_default_right()
137              ->mutable_split();
138      split->set_left_id(*children_it);
139      split->set_right_id(*++children_it);
140      break;
141    }
142    case TreeNode::kCategoricalIdBinarySplit: {
143      QCHECK(children.size() == 2)
144          << "A binary split node must have exactly two children.";
145      auto* split = parent_node->mutable_categorical_id_binary_split();
146      split->set_left_id(*children_it);
147      split->set_right_id(*++children_it);
148      break;
149    }
150    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
151      QCHECK(children.size() == 2)
152          << "A binary split node must have exactly two children.";
153      auto* split =
154          parent_node->mutable_categorical_id_set_membership_binary_split();
155      split->set_left_id(*children_it);
156      split->set_right_id(*++children_it);
157      break;
158    }
159    case TreeNode::NODE_NOT_SET: {
160      LOG(QFATAL) << "A non-set node cannot have children.";
161      break;
162    }
163  }
164}
165
166std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
167  // A node's children depend on its type.
168  switch (node.node_case()) {
169    case TreeNode::kLeaf: {
170      return {};
171    }
172    case TreeNode::kDenseFloatBinarySplit: {
173      const auto& split = node.dense_float_binary_split();
174      return {split.left_id(), split.right_id()};
175    }
176    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
177      const auto& split = node.sparse_float_binary_split_default_left().split();
178      return {split.left_id(), split.right_id()};
179    }
180    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
181      const auto& split =
182          node.sparse_float_binary_split_default_right().split();
183      return {split.left_id(), split.right_id()};
184    }
185    case TreeNode::kCategoricalIdBinarySplit: {
186      const auto& split = node.categorical_id_binary_split();
187      return {split.left_id(), split.right_id()};
188    }
189    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
190      const auto& split = node.categorical_id_set_membership_binary_split();
191      return {split.left_id(), split.right_id()};
192    }
193    case TreeNode::NODE_NOT_SET: {
194      return {};
195    }
196  }
197}
198
199}  // namespace trees
200}  // namespace boosted_trees
201}  // namespace tensorflow
202