1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/core/platform/macros.h" 17 18#include <algorithm> 19 20namespace tensorflow { 21namespace boosted_trees { 22namespace trees { 23 24constexpr int kInvalidLeaf = -1; 25int DecisionTree::Traverse(const DecisionTreeConfig& config, 26 const int32 sub_root_id, 27 const utils::Example& example) { 28 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { 29 return kInvalidLeaf; 30 } 31 32 // Traverse tree starting at the provided sub-root. 33 int32 node_id = sub_root_id; 34 while (true) { 35 const auto& current_node = config.nodes(node_id); 36 switch (current_node.node_case()) { 37 case TreeNode::kLeaf: { 38 return node_id; 39 } 40 case TreeNode::kDenseFloatBinarySplit: { 41 const auto& split = current_node.dense_float_binary_split(); 42 node_id = example.dense_float_features[split.feature_column()] <= 43 split.threshold() 44 ? split.left_id() 45 : split.right_id(); 46 break; 47 } 48 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 49 const auto& split = 50 current_node.sparse_float_binary_split_default_left().split(); 51 auto sparse_feature = 52 example.sparse_float_features[split.feature_column()]; 53 // Feature id for the split when multivalent sparse float column, or 0 54 // by default. 55 const int32 dimension_id = split.dimension_id(); 56 57 node_id = !sparse_feature[dimension_id].has_value() || 58 sparse_feature[dimension_id].get_value() <= 59 split.threshold() 60 ? split.left_id() 61 : split.right_id(); 62 break; 63 } 64 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 65 const auto& split = 66 current_node.sparse_float_binary_split_default_right().split(); 67 auto sparse_feature = 68 example.sparse_float_features[split.feature_column()]; 69 // Feature id for the split when multivalent sparse float column, or 0 70 // by default. 71 const int32 dimension_id = split.dimension_id(); 72 node_id = sparse_feature[dimension_id].has_value() && 73 sparse_feature[dimension_id].get_value() <= 74 split.threshold() 75 ? split.left_id() 76 : split.right_id(); 77 break; 78 } 79 case TreeNode::kCategoricalIdBinarySplit: { 80 const auto& split = current_node.categorical_id_binary_split(); 81 const auto& features = 82 example.sparse_int_features[split.feature_column()]; 83 node_id = features.find(split.feature_id()) != features.end() 84 ? split.left_id() 85 : split.right_id(); 86 break; 87 } 88 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 89 const auto& split = 90 current_node.categorical_id_set_membership_binary_split(); 91 // The new node_id = left_id if a feature is found, or right_id. 92 node_id = split.right_id(); 93 for (const int64 feature_id : 94 example.sparse_int_features[split.feature_column()]) { 95 if (std::binary_search(split.feature_ids().begin(), 96 split.feature_ids().end(), feature_id)) { 97 node_id = split.left_id(); 98 break; 99 } 100 } 101 break; 102 } 103 case TreeNode::NODE_NOT_SET: { 104 LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); 105 break; 106 } 107 } 108 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:" 109 << current_node.DebugString(); 110 } 111} 112 113void DecisionTree::LinkChildren(const std::vector<int32>& children, 114 TreeNode* parent_node) { 115 // Decide how to link children depending on the parent node's type. 116 auto children_it = children.begin(); 117 switch (parent_node->node_case()) { 118 case TreeNode::kLeaf: { 119 // Essentially no-op. 120 QCHECK(children.empty()) << "A leaf node cannot have children."; 121 break; 122 } 123 case TreeNode::kDenseFloatBinarySplit: { 124 QCHECK(children.size() == 2) 125 << "A binary split node must have exactly two children."; 126 auto* split = parent_node->mutable_dense_float_binary_split(); 127 split->set_left_id(*children_it); 128 split->set_right_id(*++children_it); 129 break; 130 } 131 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 132 QCHECK(children.size() == 2) 133 << "A binary split node must have exactly two children."; 134 auto* split = 135 parent_node->mutable_sparse_float_binary_split_default_left() 136 ->mutable_split(); 137 split->set_left_id(*children_it); 138 split->set_right_id(*++children_it); 139 break; 140 } 141 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 142 QCHECK(children.size() == 2) 143 << "A binary split node must have exactly two children."; 144 auto* split = 145 parent_node->mutable_sparse_float_binary_split_default_right() 146 ->mutable_split(); 147 split->set_left_id(*children_it); 148 split->set_right_id(*++children_it); 149 break; 150 } 151 case TreeNode::kCategoricalIdBinarySplit: { 152 QCHECK(children.size() == 2) 153 << "A binary split node must have exactly two children."; 154 auto* split = parent_node->mutable_categorical_id_binary_split(); 155 split->set_left_id(*children_it); 156 split->set_right_id(*++children_it); 157 break; 158 } 159 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 160 QCHECK(children.size() == 2) 161 << "A binary split node must have exactly two children."; 162 auto* split = 163 parent_node->mutable_categorical_id_set_membership_binary_split(); 164 split->set_left_id(*children_it); 165 split->set_right_id(*++children_it); 166 break; 167 } 168 case TreeNode::NODE_NOT_SET: { 169 LOG(QFATAL) << "A non-set node cannot have children."; 170 break; 171 } 172 } 173} 174 175std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { 176 // A node's children depend on its type. 177 switch (node.node_case()) { 178 case TreeNode::kLeaf: { 179 return {}; 180 } 181 case TreeNode::kDenseFloatBinarySplit: { 182 const auto& split = node.dense_float_binary_split(); 183 return {split.left_id(), split.right_id()}; 184 } 185 case TreeNode::kSparseFloatBinarySplitDefaultLeft: { 186 const auto& split = node.sparse_float_binary_split_default_left().split(); 187 return {split.left_id(), split.right_id()}; 188 } 189 case TreeNode::kSparseFloatBinarySplitDefaultRight: { 190 const auto& split = 191 node.sparse_float_binary_split_default_right().split(); 192 return {split.left_id(), split.right_id()}; 193 } 194 case TreeNode::kCategoricalIdBinarySplit: { 195 const auto& split = node.categorical_id_binary_split(); 196 return {split.left_id(), split.right_id()}; 197 } 198 case TreeNode::kCategoricalIdSetMembershipBinarySplit: { 199 const auto& split = node.categorical_id_set_membership_binary_split(); 200 return {split.left_id(), split.right_id()}; 201 } 202 case TreeNode::NODE_NOT_SET: { 203 return {}; 204 } 205 } 206} 207 208} // namespace trees 209} // namespace boosted_trees 210} // namespace tensorflow 211