decision_tree.cc revision c11ea29bcfc2a0036dbbbe5f6dc3dc16f4a01ada
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/core/platform/macros.h"
17
18namespace tensorflow {
19namespace boosted_trees {
20namespace trees {
21
22constexpr int kInvalidLeaf = -1;
23int DecisionTree::Traverse(const DecisionTreeConfig& config,
24                           const int32 sub_root_id,
25                           const utils::Example& example) {
26  if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
27    return kInvalidLeaf;
28  }
29
30  // Traverse tree starting at the provided sub-root.
31  int32 node_id = sub_root_id;
32  while (true) {
33    const auto& current_node = config.nodes(node_id);
34    switch (current_node.node_case()) {
35      case TreeNode::kLeaf: {
36        return node_id;
37      }
38      case TreeNode::kDenseFloatBinarySplit: {
39        const auto& split = current_node.dense_float_binary_split();
40        node_id = example.dense_float_features[split.feature_column()] <=
41                          split.threshold()
42                      ? split.left_id()
43                      : split.right_id();
44        break;
45      }
46      case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
47        const auto& split =
48            current_node.sparse_float_binary_split_default_left().split();
49        auto sparse_feature =
50            example.sparse_float_features[split.feature_column()];
51        node_id = !sparse_feature.has_value() ||
52                          sparse_feature.get_value() <= split.threshold()
53                      ? split.left_id()
54                      : split.right_id();
55        break;
56      }
57      case TreeNode::kSparseFloatBinarySplitDefaultRight: {
58        const auto& split =
59            current_node.sparse_float_binary_split_default_right().split();
60        auto sparse_feature =
61            example.sparse_float_features[split.feature_column()];
62        node_id = sparse_feature.has_value() &&
63                          sparse_feature.get_value() <= split.threshold()
64                      ? split.left_id()
65                      : split.right_id();
66        break;
67      }
68      case TreeNode::kCategoricalIdBinarySplit: {
69        const auto& split = current_node.categorical_id_binary_split();
70        node_id = example.sparse_int_features[split.feature_column()].count(
71                      split.feature_id()) > 0
72                      ? split.left_id()
73                      : split.right_id();
74        break;
75      }
76      case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
77        const auto& split =
78            current_node.categorical_id_set_membership_binary_split();
79        bool found = false;
80        for (const int64 feature_id :
81             example.sparse_int_features[split.feature_column()]) {
82          const auto iter =
83              std::lower_bound(split.feature_ids().begin(),
84                               split.feature_ids().end(), feature_id);
85          if (iter != split.feature_ids().end() && *iter == feature_id) {
86            node_id = split.left_id();
87            found = true;
88            break;
89          }
90        }
91        if (!found) {
92          node_id = split.right_id();
93        }
94        break;
95      }
96      case TreeNode::NODE_NOT_SET: {
97        QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
98        break;
99      }
100    }
101    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
102                          << current_node.DebugString();
103  }
104}
105
106void DecisionTree::LinkChildren(const std::vector<int32>& children,
107                                TreeNode* parent_node) {
108  // Decide how to link children depending on the parent node's type.
109  auto children_it = children.begin();
110  switch (parent_node->node_case()) {
111    case TreeNode::kLeaf: {
112      // Essentially no-op.
113      QCHECK(children.empty()) << "A leaf node cannot have children.";
114      break;
115    }
116    case TreeNode::kDenseFloatBinarySplit: {
117      QCHECK(children.size() == 2)
118          << "A binary split node must have exactly two children.";
119      auto* split = parent_node->mutable_dense_float_binary_split();
120      split->set_left_id(*children_it);
121      split->set_right_id(*++children_it);
122      break;
123    }
124    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
125      QCHECK(children.size() == 2)
126          << "A binary split node must have exactly two children.";
127      auto* split =
128          parent_node->mutable_sparse_float_binary_split_default_left()
129              ->mutable_split();
130      split->set_left_id(*children_it);
131      split->set_right_id(*++children_it);
132      break;
133    }
134    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
135      QCHECK(children.size() == 2)
136          << "A binary split node must have exactly two children.";
137      auto* split =
138          parent_node->mutable_sparse_float_binary_split_default_right()
139              ->mutable_split();
140      split->set_left_id(*children_it);
141      split->set_right_id(*++children_it);
142      break;
143    }
144    case TreeNode::kCategoricalIdBinarySplit: {
145      QCHECK(children.size() == 2)
146          << "A binary split node must have exactly two children.";
147      auto* split = parent_node->mutable_categorical_id_binary_split();
148      split->set_left_id(*children_it);
149      split->set_right_id(*++children_it);
150      break;
151    }
152    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
153      QCHECK(children.size() == 2)
154          << "A binary split node must have exactly two children.";
155      auto* split =
156          parent_node->mutable_categorical_id_set_membership_binary_split();
157      split->set_left_id(*children_it);
158      split->set_right_id(*++children_it);
159      break;
160    }
161    case TreeNode::NODE_NOT_SET: {
162      QCHECK(false) << "A non-set node cannot have children.";
163      break;
164    }
165  }
166}
167
168std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
169  // A node's children depend on its type.
170  switch (node.node_case()) {
171    case TreeNode::kLeaf: {
172      return {};
173    }
174    case TreeNode::kDenseFloatBinarySplit: {
175      const auto& split = node.dense_float_binary_split();
176      return {split.left_id(), split.right_id()};
177    }
178    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
179      const auto& split = node.sparse_float_binary_split_default_left().split();
180      return {split.left_id(), split.right_id()};
181    }
182    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
183      const auto& split =
184          node.sparse_float_binary_split_default_right().split();
185      return {split.left_id(), split.right_id()};
186    }
187    case TreeNode::kCategoricalIdBinarySplit: {
188      const auto& split = node.categorical_id_binary_split();
189      return {split.left_id(), split.right_id()};
190    }
191    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
192      const auto& split = node.categorical_id_set_membership_binary_split();
193      return {split.left_id(), split.right_id()};
194    }
195    case TreeNode::NODE_NOT_SET: {
196      return {};
197    }
198  }
199}
200
201}  // namespace trees
202}  // namespace boosted_trees
203}  // namespace tensorflow
204