decision_tree.cc revision 50be7aa7d72ded57c11c705e9de80da2bdc2220b
1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/core/platform/macros.h" 17 18namespace tensorflow { 19namespace boosted_trees { 20namespace trees { 21 22constexpr int kInvalidLeaf = -1; 23int DecisionTree::Traverse(const DecisionTreeConfig& config, 24 const int32 sub_root_id, 25 const utils::Example& example) { 26 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { 27 return kInvalidLeaf; 28 } 29 30 // Traverse tree starting at the provided sub-root. 31 int32 node_id = sub_root_id; 32 while (true) { 33 const auto& current_node = config.nodes(node_id); 34 switch (current_node.node_case()) { 35 case TreeNode::kLeaf: { 36 return node_id; 37 } 38 case TreeNode::kDenseFloatBinarySplit: { 39 const auto& split = current_node.dense_float_binary_split(); 40 node_id = example.dense_float_features[split.feature_column()] <= 41 split.threshold() 42 ? split.left_id() 43 : split.right_id(); 44 break; 45 } 46 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 47 const auto& split = 48 current_node.sparse_float_binary_split_default_left().split(); 49 auto sparse_feature = 50 example.sparse_float_features[split.feature_column()]; 51 node_id = !sparse_feature.has_value() || 52 sparse_feature.get_value() <= split.threshold() 53 ? split.left_id() 54 : split.right_id(); 55 break; 56 } 57 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 58 const auto& split = 59 current_node.sparse_float_binary_split_default_right().split(); 60 auto sparse_feature = 61 example.sparse_float_features[split.feature_column()]; 62 node_id = sparse_feature.has_value() && 63 sparse_feature.get_value() <= split.threshold() 64 ? split.left_id() 65 : split.right_id(); 66 break; 67 } 68 case TreeNode::kCategoricalIdBinarySplit: { 69 const auto& split = current_node.categorical_id_binary_split(); 70 node_id = example.sparse_int_features[split.feature_column()].count( 71 split.feature_id()) > 0 72 ? split.left_id() 73 : split.right_id(); 74 break; 75 } 76 case TreeNode::NODE_NOT_SET: { 77 QCHECK(false) << "Invalid node in tree: " << current_node.DebugString(); 78 break; 79 } 80 } 81 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:" 82 << current_node.DebugString(); 83 } 84} 85 86void DecisionTree::LinkChildren(const std::vector<int32>& children, 87 TreeNode* parent_node) { 88 // Decide how to link children depending on the parent node's type. 89 auto children_it = children.begin(); 90 switch (parent_node->node_case()) { 91 case TreeNode::kLeaf: { 92 // Essentially no-op. 93 QCHECK(children.empty()) << "A leaf node cannot have children."; 94 break; 95 } 96 case TreeNode::kDenseFloatBinarySplit: { 97 QCHECK(children.size() == 2) 98 << "A binary split node must have exactly two children."; 99 auto* split = parent_node->mutable_dense_float_binary_split(); 100 split->set_left_id(*children_it); 101 split->set_right_id(*++children_it); 102 break; 103 } 104 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 105 QCHECK(children.size() == 2) 106 << "A binary split node must have exactly two children."; 107 auto* split = 108 parent_node->mutable_sparse_float_binary_split_default_left() 109 ->mutable_split(); 110 split->set_left_id(*children_it); 111 split->set_right_id(*++children_it); 112 break; 113 } 114 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 115 QCHECK(children.size() == 2) 116 << "A binary split node must have exactly two children."; 117 auto* split = 118 parent_node->mutable_sparse_float_binary_split_default_right() 119 ->mutable_split(); 120 split->set_left_id(*children_it); 121 split->set_right_id(*++children_it); 122 break; 123 } 124 case TreeNode::kCategoricalIdBinarySplit: { 125 QCHECK(children.size() == 2) 126 << "A binary split node must have exactly two children."; 127 auto* split = parent_node->mutable_categorical_id_binary_split(); 128 split->set_left_id(*children_it); 129 split->set_right_id(*++children_it); 130 break; 131 } 132 case TreeNode::NODE_NOT_SET: { 133 QCHECK(false) << "A non-set node cannot have children."; 134 break; 135 } 136 } 137} 138 139std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { 140 // A node's children depend on its type. 141 switch (node.node_case()) { 142 case TreeNode::kLeaf: { 143 return {}; 144 } 145 case TreeNode::kDenseFloatBinarySplit: { 146 const auto& split = node.dense_float_binary_split(); 147 return {split.left_id(), split.right_id()}; 148 } 149 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 150 const auto& split = node.sparse_float_binary_split_default_left().split(); 151 return {split.left_id(), split.right_id()}; 152 } 153 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 154 const auto& split = 155 node.sparse_float_binary_split_default_right().split(); 156 return {split.left_id(), split.right_id()}; 157 } 158 case TreeNode::kCategoricalIdBinarySplit: { 159 const auto& split = node.categorical_id_binary_split(); 160 return {split.left_id(), split.right_id()}; 161 } 162 case TreeNode::NODE_NOT_SET: { 163 return {}; 164 } 165 } 166} 167 168} // namespace trees 169} // namespace boosted_trees 170} // namespace tensorflow 171