decision_tree.cc revision bcaf3cf82e23b4318f9d87f17cb5a4536febe564
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/core/platform/macros.h" 17 18namespace tensorflow { 19namespace boosted_trees { 20namespace trees { 21 22constexpr int kInvalidLeaf = -1; 23int DecisionTree::Traverse(const DecisionTreeConfig& config, 24 const int32 sub_root_id, 25 const utils::Example& example) { 26 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { 27 return kInvalidLeaf; 28 } 29 30 // Traverse tree starting at the provided sub-root. 31 int32 node_id = sub_root_id; 32 while (true) { 33 const auto& current_node = config.nodes(node_id); 34 switch (current_node.node_case()) { 35 case TreeNode::kLeaf: { 36 return node_id; 37 } 38 case TreeNode::kDenseFloatBinarySplit: { 39 const auto& split = current_node.dense_float_binary_split(); 40 node_id = example.dense_float_features[split.feature_column()] <= 41 split.threshold() 42 ? split.left_id() 43 : split.right_id(); 44 break; 45 } 46 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 47 const auto& split = 48 current_node.sparse_float_binary_split_default_left().split(); 49 auto sparse_feature = 50 example.sparse_float_features[split.feature_column()]; 51 node_id = !sparse_feature.has_value() || 52 sparse_feature.get_value() <= split.threshold() 53 ? split.left_id() 54 : split.right_id(); 55 break; 56 } 57 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 58 const auto& split = 59 current_node.sparse_float_binary_split_default_right().split(); 60 auto sparse_feature = 61 example.sparse_float_features[split.feature_column()]; 62 node_id = sparse_feature.has_value() && 63 sparse_feature.get_value() <= split.threshold() 64 ? split.left_id() 65 : split.right_id(); 66 break; 67 } 68 case TreeNode::kCategoricalIdBinarySplit: { 69 const auto& split = current_node.categorical_id_binary_split(); 70 const auto& features = 71 example.sparse_int_features[split.feature_column()]; 72 node_id = features.find(split.feature_id()) != features.end() 73 ? split.left_id() 74 : split.right_id(); 75 break; 76 } 77 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 78 const auto& split = 79 current_node.categorical_id_set_membership_binary_split(); 80 // The new node_id = left_id if a feature is found, or right_id. 81 node_id = split.right_id(); 82 for (const int64 feature_id : 83 example.sparse_int_features[split.feature_column()]) { 84 if (std::binary_search(split.feature_ids().begin(), 85 split.feature_ids().end(), feature_id)) { 86 node_id = split.left_id(); 87 break; 88 } 89 } 90 break; 91 } 92 case TreeNode::NODE_NOT_SET: { 93 QCHECK(false) << "Invalid node in tree: " << current_node.DebugString(); 94 break; 95 } 96 } 97 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:" 98 << current_node.DebugString(); 99 } 100} 101 102void DecisionTree::LinkChildren(const std::vector<int32>& children, 103 TreeNode* parent_node) { 104 // Decide how to link children depending on the parent node's type. 105 auto children_it = children.begin(); 106 switch (parent_node->node_case()) { 107 case TreeNode::kLeaf: { 108 // Essentially no-op. 109 QCHECK(children.empty()) << "A leaf node cannot have children."; 110 break; 111 } 112 case TreeNode::kDenseFloatBinarySplit: { 113 QCHECK(children.size() == 2) 114 << "A binary split node must have exactly two children."; 115 auto* split = parent_node->mutable_dense_float_binary_split(); 116 split->set_left_id(*children_it); 117 split->set_right_id(*++children_it); 118 break; 119 } 120 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 121 QCHECK(children.size() == 2) 122 << "A binary split node must have exactly two children."; 123 auto* split = 124 parent_node->mutable_sparse_float_binary_split_default_left() 125 ->mutable_split(); 126 split->set_left_id(*children_it); 127 split->set_right_id(*++children_it); 128 break; 129 } 130 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 131 QCHECK(children.size() == 2) 132 << "A binary split node must have exactly two children."; 133 auto* split = 134 parent_node->mutable_sparse_float_binary_split_default_right() 135 ->mutable_split(); 136 split->set_left_id(*children_it); 137 split->set_right_id(*++children_it); 138 break; 139 } 140 case TreeNode::kCategoricalIdBinarySplit: { 141 QCHECK(children.size() == 2) 142 << "A binary split node must have exactly two children."; 143 auto* split = parent_node->mutable_categorical_id_binary_split(); 144 split->set_left_id(*children_it); 145 split->set_right_id(*++children_it); 146 break; 147 } 148 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 149 QCHECK(children.size() == 2) 150 << "A binary split node must have exactly two children."; 151 auto* split = 152 parent_node->mutable_categorical_id_set_membership_binary_split(); 153 split->set_left_id(*children_it); 154 split->set_right_id(*++children_it); 155 break; 156 } 157 case TreeNode::NODE_NOT_SET: { 158 QCHECK(false) << "A non-set node cannot have children."; 159 break; 160 } 161 } 162} 163 164std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { 165 // A node's children depend on its type. 166 switch (node.node_case()) { 167 case TreeNode::kLeaf: { 168 return {}; 169 } 170 case TreeNode::kDenseFloatBinarySplit: { 171 const auto& split = node.dense_float_binary_split(); 172 return {split.left_id(), split.right_id()}; 173 } 174 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 175 const auto& split = node.sparse_float_binary_split_default_left().split(); 176 return {split.left_id(), split.right_id()}; 177 } 178 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 179 const auto& split = 180 node.sparse_float_binary_split_default_right().split(); 181 return {split.left_id(), split.right_id()}; 182 } 183 case TreeNode::kCategoricalIdBinarySplit: { 184 const auto& split = node.categorical_id_binary_split(); 185 return {split.left_id(), split.right_id()}; 186 } 187 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 188 const auto& split = node.categorical_id_set_membership_binary_split(); 189 return {split.left_id(), split.right_id()}; 190 } 191 case TreeNode::NODE_NOT_SET: { 192 return {}; 193 } 194 } 195} 196 197} // namespace trees 198} // namespace boosted_trees 199} // namespace tensorflow 200