decision_tree.cc revision c11ea29bcfc2a0036dbbbe5f6dc3dc16f4a01ada
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/core/platform/macros.h" 17 18namespace tensorflow { 19namespace boosted_trees { 20namespace trees { 21 22constexpr int kInvalidLeaf = -1; 23int DecisionTree::Traverse(const DecisionTreeConfig& config, 24 const int32 sub_root_id, 25 const utils::Example& example) { 26 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { 27 return kInvalidLeaf; 28 } 29 30 // Traverse tree starting at the provided sub-root. 31 int32 node_id = sub_root_id; 32 while (true) { 33 const auto& current_node = config.nodes(node_id); 34 switch (current_node.node_case()) { 35 case TreeNode::kLeaf: { 36 return node_id; 37 } 38 case TreeNode::kDenseFloatBinarySplit: { 39 const auto& split = current_node.dense_float_binary_split(); 40 node_id = example.dense_float_features[split.feature_column()] <= 41 split.threshold() 42 ? split.left_id() 43 : split.right_id(); 44 break; 45 } 46 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 47 const auto& split = 48 current_node.sparse_float_binary_split_default_left().split(); 49 auto sparse_feature = 50 example.sparse_float_features[split.feature_column()]; 51 node_id = !sparse_feature.has_value() || 52 sparse_feature.get_value() <= split.threshold() 53 ? split.left_id() 54 : split.right_id(); 55 break; 56 } 57 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 58 const auto& split = 59 current_node.sparse_float_binary_split_default_right().split(); 60 auto sparse_feature = 61 example.sparse_float_features[split.feature_column()]; 62 node_id = sparse_feature.has_value() && 63 sparse_feature.get_value() <= split.threshold() 64 ? split.left_id() 65 : split.right_id(); 66 break; 67 } 68 case TreeNode::kCategoricalIdBinarySplit: { 69 const auto& split = current_node.categorical_id_binary_split(); 70 node_id = example.sparse_int_features[split.feature_column()].count( 71 split.feature_id()) > 0 72 ? split.left_id() 73 : split.right_id(); 74 break; 75 } 76 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 77 const auto& split = 78 current_node.categorical_id_set_membership_binary_split(); 79 bool found = false; 80 for (const int64 feature_id : 81 example.sparse_int_features[split.feature_column()]) { 82 const auto iter = 83 std::lower_bound(split.feature_ids().begin(), 84 split.feature_ids().end(), feature_id); 85 if (iter != split.feature_ids().end() && *iter == feature_id) { 86 node_id = split.left_id(); 87 found = true; 88 break; 89 } 90 } 91 if (!found) { 92 node_id = split.right_id(); 93 } 94 break; 95 } 96 case TreeNode::NODE_NOT_SET: { 97 QCHECK(false) << "Invalid node in tree: " << current_node.DebugString(); 98 break; 99 } 100 } 101 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:" 102 << current_node.DebugString(); 103 } 104} 105 106void DecisionTree::LinkChildren(const std::vector<int32>& children, 107 TreeNode* parent_node) { 108 // Decide how to link children depending on the parent node's type. 109 auto children_it = children.begin(); 110 switch (parent_node->node_case()) { 111 case TreeNode::kLeaf: { 112 // Essentially no-op. 113 QCHECK(children.empty()) << "A leaf node cannot have children."; 114 break; 115 } 116 case TreeNode::kDenseFloatBinarySplit: { 117 QCHECK(children.size() == 2) 118 << "A binary split node must have exactly two children."; 119 auto* split = parent_node->mutable_dense_float_binary_split(); 120 split->set_left_id(*children_it); 121 split->set_right_id(*++children_it); 122 break; 123 } 124 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 125 QCHECK(children.size() == 2) 126 << "A binary split node must have exactly two children."; 127 auto* split = 128 parent_node->mutable_sparse_float_binary_split_default_left() 129 ->mutable_split(); 130 split->set_left_id(*children_it); 131 split->set_right_id(*++children_it); 132 break; 133 } 134 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 135 QCHECK(children.size() == 2) 136 << "A binary split node must have exactly two children."; 137 auto* split = 138 parent_node->mutable_sparse_float_binary_split_default_right() 139 ->mutable_split(); 140 split->set_left_id(*children_it); 141 split->set_right_id(*++children_it); 142 break; 143 } 144 case TreeNode::kCategoricalIdBinarySplit: { 145 QCHECK(children.size() == 2) 146 << "A binary split node must have exactly two children."; 147 auto* split = parent_node->mutable_categorical_id_binary_split(); 148 split->set_left_id(*children_it); 149 split->set_right_id(*++children_it); 150 break; 151 } 152 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 153 QCHECK(children.size() == 2) 154 << "A binary split node must have exactly two children."; 155 auto* split = 156 parent_node->mutable_categorical_id_set_membership_binary_split(); 157 split->set_left_id(*children_it); 158 split->set_right_id(*++children_it); 159 break; 160 } 161 case TreeNode::NODE_NOT_SET: { 162 QCHECK(false) << "A non-set node cannot have children."; 163 break; 164 } 165 } 166} 167 168std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { 169 // A node's children depend on its type. 170 switch (node.node_case()) { 171 case TreeNode::kLeaf: { 172 return {}; 173 } 174 case TreeNode::kDenseFloatBinarySplit: { 175 const auto& split = node.dense_float_binary_split(); 176 return {split.left_id(), split.right_id()}; 177 } 178 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 179 const auto& split = node.sparse_float_binary_split_default_left().split(); 180 return {split.left_id(), split.right_id()}; 181 } 182 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 183 const auto& split = 184 node.sparse_float_binary_split_default_right().split(); 185 return {split.left_id(), split.right_id()}; 186 } 187 case TreeNode::kCategoricalIdBinarySplit: { 188 const auto& split = node.categorical_id_binary_split(); 189 return {split.left_id(), split.right_id()}; 190 } 191 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 192 const auto& split = node.categorical_id_set_membership_binary_split(); 193 return {split.left_id(), split.right_id()}; 194 } 195 case TreeNode::NODE_NOT_SET: { 196 return {}; 197 } 198 } 199} 200 201} // namespace trees 202} // namespace boosted_trees 203} // namespace tensorflow 204