1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/core/platform/macros.h"
17
18#include <algorithm>
19
20namespace tensorflow {
21namespace boosted_trees {
22namespace trees {
23
24constexpr int kInvalidLeaf = -1;
25int DecisionTree::Traverse(const DecisionTreeConfig& config,
26                           const int32 sub_root_id,
27                           const utils::Example& example) {
28  if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
29    return kInvalidLeaf;
30  }
31
32  // Traverse tree starting at the provided sub-root.
33  int32 node_id = sub_root_id;
34  while (true) {
35    const auto& current_node = config.nodes(node_id);
36    switch (current_node.node_case()) {
37      case TreeNode::kLeaf: {
38        return node_id;
39      }
40      case TreeNode::kDenseFloatBinarySplit: {
41        const auto& split = current_node.dense_float_binary_split();
42        node_id = example.dense_float_features[split.feature_column()] <=
43                          split.threshold()
44                      ? split.left_id()
45                      : split.right_id();
46        break;
47      }
48      case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
49        const auto& split =
50            current_node.sparse_float_binary_split_default_left().split();
51        auto sparse_feature =
52            example.sparse_float_features[split.feature_column()];
53        // Feature id for the split when multivalent sparse float column, or 0
54        // by default.
55        const int32 dimension_id = split.dimension_id();
56
57        node_id = !sparse_feature[dimension_id].has_value() ||
58                          sparse_feature[dimension_id].get_value() <=
59                              split.threshold()
60                      ? split.left_id()
61                      : split.right_id();
62        break;
63      }
64      case TreeNode::kSparseFloatBinarySplitDefaultRight: {
65        const auto& split =
66            current_node.sparse_float_binary_split_default_right().split();
67        auto sparse_feature =
68            example.sparse_float_features[split.feature_column()];
69        // Feature id for the split when multivalent sparse float column, or 0
70        // by default.
71        const int32 dimension_id = split.dimension_id();
72        node_id = sparse_feature[dimension_id].has_value() &&
73                          sparse_feature[dimension_id].get_value() <=
74                              split.threshold()
75                      ? split.left_id()
76                      : split.right_id();
77        break;
78      }
79      case TreeNode::kCategoricalIdBinarySplit: {
80        const auto& split = current_node.categorical_id_binary_split();
81        const auto& features =
82            example.sparse_int_features[split.feature_column()];
83        node_id = features.find(split.feature_id()) != features.end()
84                      ? split.left_id()
85                      : split.right_id();
86        break;
87      }
88      case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
89        const auto& split =
90            current_node.categorical_id_set_membership_binary_split();
91        // The new node_id = left_id if a feature is found, or right_id.
92        node_id = split.right_id();
93        for (const int64 feature_id :
94             example.sparse_int_features[split.feature_column()]) {
95          if (std::binary_search(split.feature_ids().begin(),
96                                 split.feature_ids().end(), feature_id)) {
97            node_id = split.left_id();
98            break;
99          }
100        }
101        break;
102      }
103      case TreeNode::NODE_NOT_SET: {
104        LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
105        break;
106      }
107    }
108    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
109                          << current_node.DebugString();
110  }
111}
112
113void DecisionTree::LinkChildren(const std::vector<int32>& children,
114                                TreeNode* parent_node) {
115  // Decide how to link children depending on the parent node's type.
116  auto children_it = children.begin();
117  switch (parent_node->node_case()) {
118    case TreeNode::kLeaf: {
119      // Essentially no-op.
120      QCHECK(children.empty()) << "A leaf node cannot have children.";
121      break;
122    }
123    case TreeNode::kDenseFloatBinarySplit: {
124      QCHECK(children.size() == 2)
125          << "A binary split node must have exactly two children.";
126      auto* split = parent_node->mutable_dense_float_binary_split();
127      split->set_left_id(*children_it);
128      split->set_right_id(*++children_it);
129      break;
130    }
131    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
132      QCHECK(children.size() == 2)
133          << "A binary split node must have exactly two children.";
134      auto* split =
135          parent_node->mutable_sparse_float_binary_split_default_left()
136              ->mutable_split();
137      split->set_left_id(*children_it);
138      split->set_right_id(*++children_it);
139      break;
140    }
141    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
142      QCHECK(children.size() == 2)
143          << "A binary split node must have exactly two children.";
144      auto* split =
145          parent_node->mutable_sparse_float_binary_split_default_right()
146              ->mutable_split();
147      split->set_left_id(*children_it);
148      split->set_right_id(*++children_it);
149      break;
150    }
151    case TreeNode::kCategoricalIdBinarySplit: {
152      QCHECK(children.size() == 2)
153          << "A binary split node must have exactly two children.";
154      auto* split = parent_node->mutable_categorical_id_binary_split();
155      split->set_left_id(*children_it);
156      split->set_right_id(*++children_it);
157      break;
158    }
159    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
160      QCHECK(children.size() == 2)
161          << "A binary split node must have exactly two children.";
162      auto* split =
163          parent_node->mutable_categorical_id_set_membership_binary_split();
164      split->set_left_id(*children_it);
165      split->set_right_id(*++children_it);
166      break;
167    }
168    case TreeNode::NODE_NOT_SET: {
169      LOG(QFATAL) << "A non-set node cannot have children.";
170      break;
171    }
172  }
173}
174
175std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
176  // A node's children depend on its type.
177  switch (node.node_case()) {
178    case TreeNode::kLeaf: {
179      return {};
180    }
181    case TreeNode::kDenseFloatBinarySplit: {
182      const auto& split = node.dense_float_binary_split();
183      return {split.left_id(), split.right_id()};
184    }
185    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
186      const auto& split = node.sparse_float_binary_split_default_left().split();
187      return {split.left_id(), split.right_id()};
188    }
189    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
190      const auto& split =
191          node.sparse_float_binary_split_default_right().split();
192      return {split.left_id(), split.right_id()};
193    }
194    case TreeNode::kCategoricalIdBinarySplit: {
195      const auto& split = node.categorical_id_binary_split();
196      return {split.left_id(), split.right_id()};
197    }
198    case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
199      const auto& split = node.categorical_id_set_membership_binary_split();
200      return {split.left_id(), split.right_id()};
201    }
202    case TreeNode::NODE_NOT_SET: {
203      return {};
204    }
205  }
206}
207
208}  // namespace trees
209}  // namespace boosted_trees
210}  // namespace tensorflow
211