decision_tree.cc revision 3f92cad88767b9e0d4febe4e02ad3c31d02d5daa
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/core/platform/macros.h" 17 18#include <algorithm> 19 20namespace tensorflow { 21namespace boosted_trees { 22namespace trees { 23 24constexpr int kInvalidLeaf = -1; 25int DecisionTree::Traverse(const DecisionTreeConfig& config, 26 const int32 sub_root_id, 27 const utils::Example& example) { 28 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { 29 return kInvalidLeaf; 30 } 31 32 // Traverse tree starting at the provided sub-root. 33 int32 node_id = sub_root_id; 34 while (true) { 35 const auto& current_node = config.nodes(node_id); 36 switch (current_node.node_case()) { 37 case TreeNode::kLeaf: { 38 return node_id; 39 } 40 case TreeNode::kDenseFloatBinarySplit: { 41 const auto& split = current_node.dense_float_binary_split(); 42 node_id = example.dense_float_features[split.feature_column()] <= 43 split.threshold() 44 ? split.left_id() 45 : split.right_id(); 46 break; 47 } 48 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 49 const auto& split = 50 current_node.sparse_float_binary_split_default_left().split(); 51 auto sparse_feature = 52 example.sparse_float_features[split.feature_column()]; 53 node_id = !sparse_feature.has_value() || 54 sparse_feature.get_value() <= split.threshold() 55 ? split.left_id() 56 : split.right_id(); 57 break; 58 } 59 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 60 const auto& split = 61 current_node.sparse_float_binary_split_default_right().split(); 62 auto sparse_feature = 63 example.sparse_float_features[split.feature_column()]; 64 node_id = sparse_feature.has_value() && 65 sparse_feature.get_value() <= split.threshold() 66 ? split.left_id() 67 : split.right_id(); 68 break; 69 } 70 case TreeNode::kCategoricalIdBinarySplit: { 71 const auto& split = current_node.categorical_id_binary_split(); 72 const auto& features = 73 example.sparse_int_features[split.feature_column()]; 74 node_id = features.find(split.feature_id()) != features.end() 75 ? split.left_id() 76 : split.right_id(); 77 break; 78 } 79 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 80 const auto& split = 81 current_node.categorical_id_set_membership_binary_split(); 82 // The new node_id = left_id if a feature is found, or right_id. 83 node_id = split.right_id(); 84 for (const int64 feature_id : 85 example.sparse_int_features[split.feature_column()]) { 86 if (std::binary_search(split.feature_ids().begin(), 87 split.feature_ids().end(), feature_id)) { 88 node_id = split.left_id(); 89 break; 90 } 91 } 92 break; 93 } 94 case TreeNode::NODE_NOT_SET: { 95 LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); 96 break; 97 } 98 } 99 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:" 100 << current_node.DebugString(); 101 } 102} 103 104void DecisionTree::LinkChildren(const std::vector<int32>& children, 105 TreeNode* parent_node) { 106 // Decide how to link children depending on the parent node's type. 107 auto children_it = children.begin(); 108 switch (parent_node->node_case()) { 109 case TreeNode::kLeaf: { 110 // Essentially no-op. 111 QCHECK(children.empty()) << "A leaf node cannot have children."; 112 break; 113 } 114 case TreeNode::kDenseFloatBinarySplit: { 115 QCHECK(children.size() == 2) 116 << "A binary split node must have exactly two children."; 117 auto* split = parent_node->mutable_dense_float_binary_split(); 118 split->set_left_id(*children_it); 119 split->set_right_id(*++children_it); 120 break; 121 } 122 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 123 QCHECK(children.size() == 2) 124 << "A binary split node must have exactly two children."; 125 auto* split = 126 parent_node->mutable_sparse_float_binary_split_default_left() 127 ->mutable_split(); 128 split->set_left_id(*children_it); 129 split->set_right_id(*++children_it); 130 break; 131 } 132 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 133 QCHECK(children.size() == 2) 134 << "A binary split node must have exactly two children."; 135 auto* split = 136 parent_node->mutable_sparse_float_binary_split_default_right() 137 ->mutable_split(); 138 split->set_left_id(*children_it); 139 split->set_right_id(*++children_it); 140 break; 141 } 142 case TreeNode::kCategoricalIdBinarySplit: { 143 QCHECK(children.size() == 2) 144 << "A binary split node must have exactly two children."; 145 auto* split = parent_node->mutable_categorical_id_binary_split(); 146 split->set_left_id(*children_it); 147 split->set_right_id(*++children_it); 148 break; 149 } 150 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 151 QCHECK(children.size() == 2) 152 << "A binary split node must have exactly two children."; 153 auto* split = 154 parent_node->mutable_categorical_id_set_membership_binary_split(); 155 split->set_left_id(*children_it); 156 split->set_right_id(*++children_it); 157 break; 158 } 159 case TreeNode::NODE_NOT_SET: { 160 LOG(QFATAL) << "A non-set node cannot have children."; 161 break; 162 } 163 } 164} 165 166std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { 167 // A node's children depend on its type. 168 switch (node.node_case()) { 169 case TreeNode::kLeaf: { 170 return {}; 171 } 172 case TreeNode::kDenseFloatBinarySplit: { 173 const auto& split = node.dense_float_binary_split(); 174 return {split.left_id(), split.right_id()}; 175 } 176 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 177 const auto& split = node.sparse_float_binary_split_default_left().split(); 178 return {split.left_id(), split.right_id()}; 179 } 180 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 181 const auto& split = 182 node.sparse_float_binary_split_default_right().split(); 183 return {split.left_id(), split.right_id()}; 184 } 185 case TreeNode::kCategoricalIdBinarySplit: { 186 const auto& split = node.categorical_id_binary_split(); 187 return {split.left_id(), split.right_id()}; 188 } 189 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 190 const auto& split = node.categorical_id_set_membership_binary_split(); 191 return {split.left_id(), split.right_id()}; 192 } 193 case TreeNode::NODE_NOT_SET: { 194 return {}; 195 } 196 } 197} 198 199} // namespace trees 200} // namespace boosted_trees 201} // namespace tensorflow 202