1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
17#include "tensorflow/core/framework/tensor_testutil.h"
18#include "tensorflow/core/lib/core/status_test_util.h"
19#include "tensorflow/core/platform/test.h"
20
21namespace tensorflow {
22namespace boosted_trees {
23namespace trees {
24namespace {
25
26class DecisionTreeTest : public ::testing::Test {
27 protected:
28  DecisionTreeTest() : batch_features_(2) {
29    // Create a batch of two examples having one dense float, two sparse float
30    // and one sparse int features, and one sparse multi-column float feature
31    // (SparseFM).
32    // The first example is missing the second sparse feature column and the
33    // second example is missing the first sparse feature column.
34    // This looks like the following:
35    // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols)
36    // 0        |   7     |   -3     |          |    3     | 3.0 |   | 1.0
37    // 1        |  -2     |          |   4      |          | 1.5 |3.5|
38    auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
39    auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
40    auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
41    auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
42    auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
43    auto sparse_float_values2 = test::AsTensor<float>({4.0f});
44    auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
45    auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
46    auto sparse_int_values1 = test::AsTensor<int64>({3});
47    auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
48
49    // Multivalent sparse feature.
50    auto multi_sparse_float_indices =
51        test::AsTensor<int64>({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2});
52    auto multi_sparse_float_values =
53        test::AsTensor<float>({3.0f, 1.0f, 1.5f, 3.5f});
54    auto multi_sparse_float_shape = test::AsTensor<int64>({2, 3});
55
56    TF_EXPECT_OK(batch_features_.Initialize(
57        {dense_float_matrix},
58        {sparse_float_indices1, sparse_float_indices2,
59         multi_sparse_float_indices},
60        {sparse_float_values1, sparse_float_values2, multi_sparse_float_values},
61        {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape},
62        {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1}));
63  }
64
65  template <typename SplitType>
66  void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
67    // Verify children were linked.
68    DecisionTree::LinkChildren({3, 8}, node);
69    EXPECT_EQ(3, split->left_id());
70    EXPECT_EQ(8, split->right_id());
71
72    // Invalid cases.
73    EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
74                 "A binary split node must have exactly two children.");
75    EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
76                 "A binary split node must have exactly two children.");
77    EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
78                 "A binary split node must have exactly two children.");
79  }
80
81  void TestGetChildren(const TreeNode& node,
82                       const std::vector<uint32>& expected_children) {
83    // Verify children were linked.
84    auto children = DecisionTree::GetChildren(node);
85    EXPECT_EQ(children.size(), expected_children.size());
86    for (size_t idx = 0; idx < children.size(); ++idx) {
87      EXPECT_EQ(children[idx], expected_children[idx]);
88    }
89  }
90
91  utils::BatchFeatures batch_features_;
92};
93
94TEST_F(DecisionTreeTest, TraverseEmpty) {
95  DecisionTreeConfig tree_config;
96  auto example = (*batch_features_.examples_iterable(0, 1).begin());
97  EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example));
98}
99
100TEST_F(DecisionTreeTest, TraverseBias) {
101  DecisionTreeConfig tree_config;
102  tree_config.add_nodes()->mutable_leaf();
103  auto example = (*batch_features_.examples_iterable(0, 1).begin());
104  EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example));
105}
106
107TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
108  DecisionTreeConfig tree_config;
109  tree_config.add_nodes()->mutable_leaf();
110  auto example = (*batch_features_.examples_iterable(0, 1).begin());
111  EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example));
112}
113
114TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
115  DecisionTreeConfig tree_config;
116  auto* split_node =
117      tree_config.add_nodes()->mutable_dense_float_binary_split();
118  split_node->set_feature_column(0);
119  split_node->set_threshold(0.0f);
120  split_node->set_left_id(1);
121  split_node->set_right_id(2);
122  tree_config.add_nodes()->mutable_leaf();
123  tree_config.add_nodes()->mutable_leaf();
124  auto example_iterable = batch_features_.examples_iterable(0, 2);
125
126  // Expect right child to be picked as !(7 <= 0);
127  auto example_it = example_iterable.begin();
128  EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
129
130  // Expect left child to be picked as (-2 <= 0);
131  EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
132}
133
134TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
135  auto example_iterable = batch_features_.examples_iterable(0, 2);
136  // Split on SparseF1.
137  // Test first sparse feature which is missing for the second example.
138  {
139    DecisionTreeConfig tree_config;
140    auto* split_node = tree_config.add_nodes()
141                           ->mutable_sparse_float_binary_split_default_left()
142                           ->mutable_split();
143    split_node->set_feature_column(0);
144    split_node->set_threshold(-20.0f);
145    split_node->set_left_id(1);
146    split_node->set_right_id(2);
147    tree_config.add_nodes()->mutable_leaf();
148    tree_config.add_nodes()->mutable_leaf();
149
150    // Expect right child to be picked as !(-3 <= -20).
151    auto example_it = example_iterable.begin();
152    EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
153
154    // Expect left child to be picked as default direction.
155    EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
156  }
157  // Split on SparseF2.
158  // Test second sparse feature which is missing for the first example.
159  {
160    DecisionTreeConfig tree_config;
161    auto* split_node = tree_config.add_nodes()
162                           ->mutable_sparse_float_binary_split_default_right()
163                           ->mutable_split();
164    split_node->set_feature_column(1);
165    split_node->set_threshold(4.0f);
166    split_node->set_left_id(1);
167    split_node->set_right_id(2);
168    tree_config.add_nodes()->mutable_leaf();
169    tree_config.add_nodes()->mutable_leaf();
170
171    // Expect right child to be picked as default direction.
172    auto example_it = example_iterable.begin();
173    EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
174
175    // Expect left child to be picked as (4 <= 4).
176    EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
177  }
178  // Split on SparseFM.
179  // Test second sparse feature which is missing for the first example.
180  {
181    DecisionTreeConfig tree_config;
182    auto* split_node = tree_config.add_nodes()
183                           ->mutable_sparse_float_binary_split_default_right()
184                           ->mutable_split();
185    split_node->set_feature_column(2);
186
187    split_node->set_left_id(1);
188    split_node->set_right_id(2);
189    tree_config.add_nodes()->mutable_leaf();
190    tree_config.add_nodes()->mutable_leaf();
191
192    // Split on first column
193    split_node->set_dimension_id(0);
194    split_node->set_threshold(2.0f);
195
196    // Both instances have this feature value.
197    auto example_it = example_iterable.begin();
198    EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
199    EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
200
201    // Split on second column
202    split_node->set_dimension_id(1);
203    split_node->set_threshold(5.0f);
204
205    // First instance does not have it (default right), second does have it.
206    example_it = example_iterable.begin();
207    EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
208    EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
209
210    // Split on third column
211    split_node->set_dimension_id(2);
212    split_node->set_threshold(3.0f);
213    example_it = example_iterable.begin();
214
215    // First instance has it, second does not (default right).
216    EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
217    EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
218  }
219}
220
221TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
222  DecisionTreeConfig tree_config;
223  auto* split_node =
224      tree_config.add_nodes()->mutable_categorical_id_binary_split();
225  split_node->set_feature_column(0);
226  split_node->set_feature_id(3);
227  split_node->set_left_id(1);
228  split_node->set_right_id(2);
229  tree_config.add_nodes()->mutable_leaf();
230  tree_config.add_nodes()->mutable_leaf();
231  auto example_iterable = batch_features_.examples_iterable(0, 2);
232
233  // Expect left child to be picked as 3 == 3;
234  auto example_it = example_iterable.begin();
235  EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
236
237  // Expect right child to be picked as the feature is missing;
238  EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
239}
240
241TEST_F(DecisionTreeTest, TraverseCategoricalIdSetMembershipBinarySplit) {
242  DecisionTreeConfig tree_config;
243  auto* split_node = tree_config.add_nodes()
244                         ->mutable_categorical_id_set_membership_binary_split();
245  split_node->set_feature_column(0);
246  split_node->add_feature_ids(3);
247  split_node->set_left_id(1);
248  split_node->set_right_id(2);
249  tree_config.add_nodes()->mutable_leaf();
250  tree_config.add_nodes()->mutable_leaf();
251  auto example_iterable = batch_features_.examples_iterable(0, 2);
252
253  // Expect left child to be picked as 3 in {3};
254  auto example_it = example_iterable.begin();
255  EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
256
257  // Expect right child to be picked as the feature is missing;
258  EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
259}
260
261TEST_F(DecisionTreeTest, TraverseHybridSplits) {
262  DecisionTreeConfig tree_config;
263  auto* split_node1 =
264      tree_config.add_nodes()->mutable_dense_float_binary_split();
265  split_node1->set_feature_column(0);
266  split_node1->set_threshold(9.0f);
267  split_node1->set_left_id(1);   // sparse split.
268  split_node1->set_right_id(2);  // leaf
269  auto* split_node2 = tree_config.add_nodes()
270                          ->mutable_sparse_float_binary_split_default_left()
271                          ->mutable_split();
272  tree_config.add_nodes()->mutable_leaf();
273  split_node2->set_feature_column(0);
274  split_node2->set_threshold(-20.0f);
275  split_node2->set_left_id(3);
276  split_node2->set_right_id(4);
277  auto* split_node3 =
278      tree_config.add_nodes()->mutable_categorical_id_binary_split();
279  split_node3->set_feature_column(0);
280  split_node3->set_feature_id(2);
281  split_node3->set_left_id(5);
282  split_node3->set_right_id(6);
283  tree_config.add_nodes()->mutable_leaf();
284  tree_config.add_nodes()->mutable_leaf();
285  tree_config.add_nodes()->mutable_leaf();
286  auto example_iterable = batch_features_.examples_iterable(0, 2);
287
288  // Expect will go left through the first dense split as (7.0f <= 9.0f),
289  // then will go right through the sparse split as !(-3 <= -20).
290  auto example_it = example_iterable.begin();
291  EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it));
292
293  // Expect will go left through the first dense split as (-2.0f <= 9.0f),
294  // then will go left the default direction as the sparse feature is missing,
295  // then will go right as 2 != 3 on the categorical split.
296  EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it));
297}
298
299TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
300  // Create leaf node.
301  TreeNode node;
302  node.mutable_leaf();
303
304  // No-op.
305  DecisionTree::LinkChildren({}, &node);
306
307  // Invalid case.
308  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
309               "A leaf node cannot have children.");
310}
311
312TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
313  TreeNode node;
314  auto* split = node.mutable_dense_float_binary_split();
315  split->set_left_id(-1);
316  split->set_right_id(-1);
317  TestLinkChildrenBinary(&node, split);
318}
319
320TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
321  TreeNode node;
322  auto* split =
323      node.mutable_sparse_float_binary_split_default_left()->mutable_split();
324  split->set_left_id(-1);
325  split->set_right_id(-1);
326  TestLinkChildrenBinary(&node, split);
327}
328
329TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
330  TreeNode node;
331  auto* split =
332      node.mutable_sparse_float_binary_split_default_right()->mutable_split();
333  split->set_left_id(-1);
334  split->set_right_id(-1);
335  TestLinkChildrenBinary(&node, split);
336}
337
338TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
339  TreeNode node;
340  auto* split = node.mutable_categorical_id_binary_split();
341  split->set_left_id(-1);
342  split->set_right_id(-1);
343  TestLinkChildrenBinary(&node, split);
344}
345
346TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
347  // Create unset node.
348  TreeNode node;
349
350  // Invalid case.
351  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
352               "A non-set node cannot have children.");
353}
354
355TEST_F(DecisionTreeTest, GetChildrenLeaf) {
356  TreeNode node;
357  node.mutable_leaf();
358  TestGetChildren(node, {});
359}
360
361TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
362  TreeNode node;
363  auto* split = node.mutable_dense_float_binary_split();
364  split->set_left_id(23);
365  split->set_right_id(24);
366  TestGetChildren(node, {23, 24});
367}
368
369TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
370  TreeNode node;
371  auto* split =
372      node.mutable_sparse_float_binary_split_default_left()->mutable_split();
373  split->set_left_id(12);
374  split->set_right_id(13);
375  TestGetChildren(node, {12, 13});
376}
377
378TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
379  TreeNode node;
380  auto* split =
381      node.mutable_sparse_float_binary_split_default_right()->mutable_split();
382  split->set_left_id(1);
383  split->set_right_id(2);
384  TestGetChildren(node, {1, 2});
385}
386
387TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
388  TreeNode node;
389  auto* split = node.mutable_categorical_id_binary_split();
390  split->set_left_id(7);
391  split->set_right_id(8);
392  TestGetChildren(node, {7, 8});
393}
394
395TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
396  TreeNode node;
397  TestGetChildren(node, {});
398}
399
400}  // namespace
401}  // namespace trees
402}  // namespace boosted_trees
403}  // namespace tensorflow
404