decision_tree.cc revision 50be7aa7d72ded57c11c705e9de80da2bdc2220b
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/core/platform/macros.h"
17
18namespace tensorflow {
19namespace boosted_trees {
20namespace trees {
21
22constexpr int kInvalidLeaf = -1;
23int DecisionTree::Traverse(const DecisionTreeConfig& config,
24                           const int32 sub_root_id,
25                           const utils::Example& example) {
26  if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
27    return kInvalidLeaf;
28  }
29
30  // Traverse tree starting at the provided sub-root.
31  int32 node_id = sub_root_id;
32  while (true) {
33    const auto& current_node = config.nodes(node_id);
34    switch (current_node.node_case()) {
35      case TreeNode::kLeaf: {
36        return node_id;
37      }
38      case TreeNode::kDenseFloatBinarySplit: {
39        const auto& split = current_node.dense_float_binary_split();
40        node_id = example.dense_float_features[split.feature_column()] <=
41                          split.threshold()
42                      ? split.left_id()
43                      : split.right_id();
44        break;
45      }
46      case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
47        const auto& split =
48            current_node.sparse_float_binary_split_default_left().split();
49        auto sparse_feature =
50            example.sparse_float_features[split.feature_column()];
51        node_id = !sparse_feature.has_value() ||
52                          sparse_feature.get_value() <= split.threshold()
53                      ? split.left_id()
54                      : split.right_id();
55        break;
56      }
57      case TreeNode::kSparseFloatBinarySplitDefaultRight: {
58        const auto& split =
59            current_node.sparse_float_binary_split_default_right().split();
60        auto sparse_feature =
61            example.sparse_float_features[split.feature_column()];
62        node_id = sparse_feature.has_value() &&
63                          sparse_feature.get_value() <= split.threshold()
64                      ? split.left_id()
65                      : split.right_id();
66        break;
67      }
68      case TreeNode::kCategoricalIdBinarySplit: {
69        const auto& split = current_node.categorical_id_binary_split();
70        node_id = example.sparse_int_features[split.feature_column()].count(
71                      split.feature_id()) > 0
72                      ? split.left_id()
73                      : split.right_id();
74        break;
75      }
76      case TreeNode::NODE_NOT_SET: {
77        QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
78        break;
79      }
80    }
81    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
82                          << current_node.DebugString();
83  }
84}
85
86void DecisionTree::LinkChildren(const std::vector<int32>& children,
87                                TreeNode* parent_node) {
88  // Decide how to link children depending on the parent node's type.
89  auto children_it = children.begin();
90  switch (parent_node->node_case()) {
91    case TreeNode::kLeaf: {
92      // Essentially no-op.
93      QCHECK(children.empty()) << "A leaf node cannot have children.";
94      break;
95    }
96    case TreeNode::kDenseFloatBinarySplit: {
97      QCHECK(children.size() == 2)
98          << "A binary split node must have exactly two children.";
99      auto* split = parent_node->mutable_dense_float_binary_split();
100      split->set_left_id(*children_it);
101      split->set_right_id(*++children_it);
102      break;
103    }
104    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
105      QCHECK(children.size() == 2)
106          << "A binary split node must have exactly two children.";
107      auto* split =
108          parent_node->mutable_sparse_float_binary_split_default_left()
109              ->mutable_split();
110      split->set_left_id(*children_it);
111      split->set_right_id(*++children_it);
112      break;
113    }
114    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
115      QCHECK(children.size() == 2)
116          << "A binary split node must have exactly two children.";
117      auto* split =
118          parent_node->mutable_sparse_float_binary_split_default_right()
119              ->mutable_split();
120      split->set_left_id(*children_it);
121      split->set_right_id(*++children_it);
122      break;
123    }
124    case TreeNode::kCategoricalIdBinarySplit: {
125      QCHECK(children.size() == 2)
126          << "A binary split node must have exactly two children.";
127      auto* split = parent_node->mutable_categorical_id_binary_split();
128      split->set_left_id(*children_it);
129      split->set_right_id(*++children_it);
130      break;
131    }
132    case TreeNode::NODE_NOT_SET: {
133      QCHECK(false) << "A non-set node cannot have children.";
134      break;
135    }
136  }
137}
138
139std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
140  // A node's children depend on its type.
141  switch (node.node_case()) {
142    case TreeNode::kLeaf: {
143      return {};
144    }
145    case TreeNode::kDenseFloatBinarySplit: {
146      const auto& split = node.dense_float_binary_split();
147      return {split.left_id(), split.right_id()};
148    }
149    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
150      const auto& split = node.sparse_float_binary_split_default_left().split();
151      return {split.left_id(), split.right_id()};
152    }
153    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
154      const auto& split =
155          node.sparse_float_binary_split_default_right().split();
156      return {split.left_id(), split.right_id()};
157    }
158    case TreeNode::kCategoricalIdBinarySplit: {
159      const auto& split = node.categorical_id_binary_split();
160      return {split.left_id(), split.right_id()};
161    }
162    case TreeNode::NODE_NOT_SET: {
163      return {};
164    }
165  }
166}
167
168}  // namespace trees
169}  // namespace boosted_trees
170}  // namespace tensorflow
171