1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 16#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" 17#include "tensorflow/core/framework/tensor_testutil.h" 18#include "tensorflow/core/lib/core/status_test_util.h" 19#include "tensorflow/core/platform/test.h" 20 21namespace tensorflow { 22namespace boosted_trees { 23namespace trees { 24namespace { 25 26class DecisionTreeTest : public ::testing::Test { 27 protected: 28 DecisionTreeTest() : batch_features_(2) { 29 // Create a batch of two examples having one dense float, two sparse float 30 // and one sparse int features, and one sparse multi-column float feature 31 // (SparseFM). 32 // The first example is missing the second sparse feature column and the 33 // second example is missing the first sparse feature column. 34 // This looks like the following: 35 // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols) 36 // 0 | 7 | -3 | | 3 | 3.0 | | 1.0 37 // 1 | -2 | | 4 | | 1.5 |3.5| 38 auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1}); 39 auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2}); 40 auto sparse_float_values1 = test::AsTensor<float>({-3.0f}); 41 auto sparse_float_shape1 = test::AsTensor<int64>({2, 1}); 42 auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2}); 43 auto sparse_float_values2 = test::AsTensor<float>({4.0f}); 44 auto sparse_float_shape2 = test::AsTensor<int64>({2, 1}); 45 auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2}); 46 auto sparse_int_values1 = test::AsTensor<int64>({3}); 47 auto sparse_int_shape1 = test::AsTensor<int64>({2, 1}); 48 49 // Multivalent sparse feature. 50 auto multi_sparse_float_indices = 51 test::AsTensor<int64>({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2}); 52 auto multi_sparse_float_values = 53 test::AsTensor<float>({3.0f, 1.0f, 1.5f, 3.5f}); 54 auto multi_sparse_float_shape = test::AsTensor<int64>({2, 3}); 55 56 TF_EXPECT_OK(batch_features_.Initialize( 57 {dense_float_matrix}, 58 {sparse_float_indices1, sparse_float_indices2, 59 multi_sparse_float_indices}, 60 {sparse_float_values1, sparse_float_values2, multi_sparse_float_values}, 61 {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape}, 62 {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1})); 63 } 64 65 template <typename SplitType> 66 void TestLinkChildrenBinary(TreeNode* node, SplitType* split) { 67 // Verify children were linked. 68 DecisionTree::LinkChildren({3, 8}, node); 69 EXPECT_EQ(3, split->left_id()); 70 EXPECT_EQ(8, split->right_id()); 71 72 // Invalid cases. 73 EXPECT_DEATH(DecisionTree::LinkChildren({}, node), 74 "A binary split node must have exactly two children."); 75 EXPECT_DEATH(DecisionTree::LinkChildren({3}, node), 76 "A binary split node must have exactly two children."); 77 EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node), 78 "A binary split node must have exactly two children."); 79 } 80 81 void TestGetChildren(const TreeNode& node, 82 const std::vector<uint32>& expected_children) { 83 // Verify children were linked. 84 auto children = DecisionTree::GetChildren(node); 85 EXPECT_EQ(children.size(), expected_children.size()); 86 for (size_t idx = 0; idx < children.size(); ++idx) { 87 EXPECT_EQ(children[idx], expected_children[idx]); 88 } 89 } 90 91 utils::BatchFeatures batch_features_; 92}; 93 94TEST_F(DecisionTreeTest, TraverseEmpty) { 95 DecisionTreeConfig tree_config; 96 auto example = (*batch_features_.examples_iterable(0, 1).begin()); 97 EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example)); 98} 99 100TEST_F(DecisionTreeTest, TraverseBias) { 101 DecisionTreeConfig tree_config; 102 tree_config.add_nodes()->mutable_leaf(); 103 auto example = (*batch_features_.examples_iterable(0, 1).begin()); 104 EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example)); 105} 106 107TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) { 108 DecisionTreeConfig tree_config; 109 tree_config.add_nodes()->mutable_leaf(); 110 auto example = (*batch_features_.examples_iterable(0, 1).begin()); 111 EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example)); 112} 113 114TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) { 115 DecisionTreeConfig tree_config; 116 auto* split_node = 117 tree_config.add_nodes()->mutable_dense_float_binary_split(); 118 split_node->set_feature_column(0); 119 split_node->set_threshold(0.0f); 120 split_node->set_left_id(1); 121 split_node->set_right_id(2); 122 tree_config.add_nodes()->mutable_leaf(); 123 tree_config.add_nodes()->mutable_leaf(); 124 auto example_iterable = batch_features_.examples_iterable(0, 2); 125 126 // Expect right child to be picked as !(7 <= 0); 127 auto example_it = example_iterable.begin(); 128 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); 129 130 // Expect left child to be picked as (-2 <= 0); 131 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); 132} 133 134TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { 135 auto example_iterable = batch_features_.examples_iterable(0, 2); 136 // Split on SparseF1. 137 // Test first sparse feature which is missing for the second example. 138 { 139 DecisionTreeConfig tree_config; 140 auto* split_node = tree_config.add_nodes() 141 ->mutable_sparse_float_binary_split_default_left() 142 ->mutable_split(); 143 split_node->set_feature_column(0); 144 split_node->set_threshold(-20.0f); 145 split_node->set_left_id(1); 146 split_node->set_right_id(2); 147 tree_config.add_nodes()->mutable_leaf(); 148 tree_config.add_nodes()->mutable_leaf(); 149 150 // Expect right child to be picked as !(-3 <= -20). 151 auto example_it = example_iterable.begin(); 152 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); 153 154 // Expect left child to be picked as default direction. 155 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); 156 } 157 // Split on SparseF2. 158 // Test second sparse feature which is missing for the first example. 159 { 160 DecisionTreeConfig tree_config; 161 auto* split_node = tree_config.add_nodes() 162 ->mutable_sparse_float_binary_split_default_right() 163 ->mutable_split(); 164 split_node->set_feature_column(1); 165 split_node->set_threshold(4.0f); 166 split_node->set_left_id(1); 167 split_node->set_right_id(2); 168 tree_config.add_nodes()->mutable_leaf(); 169 tree_config.add_nodes()->mutable_leaf(); 170 171 // Expect right child to be picked as default direction. 172 auto example_it = example_iterable.begin(); 173 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); 174 175 // Expect left child to be picked as (4 <= 4). 176 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); 177 } 178 // Split on SparseFM. 179 // Test second sparse feature which is missing for the first example. 180 { 181 DecisionTreeConfig tree_config; 182 auto* split_node = tree_config.add_nodes() 183 ->mutable_sparse_float_binary_split_default_right() 184 ->mutable_split(); 185 split_node->set_feature_column(2); 186 187 split_node->set_left_id(1); 188 split_node->set_right_id(2); 189 tree_config.add_nodes()->mutable_leaf(); 190 tree_config.add_nodes()->mutable_leaf(); 191 192 // Split on first column 193 split_node->set_dimension_id(0); 194 split_node->set_threshold(2.0f); 195 196 // Both instances have this feature value. 197 auto example_it = example_iterable.begin(); 198 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); 199 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); 200 201 // Split on second column 202 split_node->set_dimension_id(1); 203 split_node->set_threshold(5.0f); 204 205 // First instance does not have it (default right), second does have it. 206 example_it = example_iterable.begin(); 207 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); 208 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); 209 210 // Split on third column 211 split_node->set_dimension_id(2); 212 split_node->set_threshold(3.0f); 213 example_it = example_iterable.begin(); 214 215 // First instance has it, second does not (default right). 216 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); 217 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); 218 } 219} 220 221TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) { 222 DecisionTreeConfig tree_config; 223 auto* split_node = 224 tree_config.add_nodes()->mutable_categorical_id_binary_split(); 225 split_node->set_feature_column(0); 226 split_node->set_feature_id(3); 227 split_node->set_left_id(1); 228 split_node->set_right_id(2); 229 tree_config.add_nodes()->mutable_leaf(); 230 tree_config.add_nodes()->mutable_leaf(); 231 auto example_iterable = batch_features_.examples_iterable(0, 2); 232 233 // Expect left child to be picked as 3 == 3; 234 auto example_it = example_iterable.begin(); 235 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); 236 237 // Expect right child to be picked as the feature is missing; 238 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); 239} 240 241TEST_F(DecisionTreeTest, TraverseCategoricalIdSetMembershipBinarySplit) { 242 DecisionTreeConfig tree_config; 243 auto* split_node = tree_config.add_nodes() 244 ->mutable_categorical_id_set_membership_binary_split(); 245 split_node->set_feature_column(0); 246 split_node->add_feature_ids(3); 247 split_node->set_left_id(1); 248 split_node->set_right_id(2); 249 tree_config.add_nodes()->mutable_leaf(); 250 tree_config.add_nodes()->mutable_leaf(); 251 auto example_iterable = batch_features_.examples_iterable(0, 2); 252 253 // Expect left child to be picked as 3 in {3}; 254 auto example_it = example_iterable.begin(); 255 EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); 256 257 // Expect right child to be picked as the feature is missing; 258 EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); 259} 260 261TEST_F(DecisionTreeTest, TraverseHybridSplits) { 262 DecisionTreeConfig tree_config; 263 auto* split_node1 = 264 tree_config.add_nodes()->mutable_dense_float_binary_split(); 265 split_node1->set_feature_column(0); 266 split_node1->set_threshold(9.0f); 267 split_node1->set_left_id(1); // sparse split. 268 split_node1->set_right_id(2); // leaf 269 auto* split_node2 = tree_config.add_nodes() 270 ->mutable_sparse_float_binary_split_default_left() 271 ->mutable_split(); 272 tree_config.add_nodes()->mutable_leaf(); 273 split_node2->set_feature_column(0); 274 split_node2->set_threshold(-20.0f); 275 split_node2->set_left_id(3); 276 split_node2->set_right_id(4); 277 auto* split_node3 = 278 tree_config.add_nodes()->mutable_categorical_id_binary_split(); 279 split_node3->set_feature_column(0); 280 split_node3->set_feature_id(2); 281 split_node3->set_left_id(5); 282 split_node3->set_right_id(6); 283 tree_config.add_nodes()->mutable_leaf(); 284 tree_config.add_nodes()->mutable_leaf(); 285 tree_config.add_nodes()->mutable_leaf(); 286 auto example_iterable = batch_features_.examples_iterable(0, 2); 287 288 // Expect will go left through the first dense split as (7.0f <= 9.0f), 289 // then will go right through the sparse split as !(-3 <= -20). 290 auto example_it = example_iterable.begin(); 291 EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it)); 292 293 // Expect will go left through the first dense split as (-2.0f <= 9.0f), 294 // then will go left the default direction as the sparse feature is missing, 295 // then will go right as 2 != 3 on the categorical split. 296 EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it)); 297} 298 299TEST_F(DecisionTreeTest, LinkChildrenLeaf) { 300 // Create leaf node. 301 TreeNode node; 302 node.mutable_leaf(); 303 304 // No-op. 305 DecisionTree::LinkChildren({}, &node); 306 307 // Invalid case. 308 EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), 309 "A leaf node cannot have children."); 310} 311 312TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) { 313 TreeNode node; 314 auto* split = node.mutable_dense_float_binary_split(); 315 split->set_left_id(-1); 316 split->set_right_id(-1); 317 TestLinkChildrenBinary(&node, split); 318} 319 320TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) { 321 TreeNode node; 322 auto* split = 323 node.mutable_sparse_float_binary_split_default_left()->mutable_split(); 324 split->set_left_id(-1); 325 split->set_right_id(-1); 326 TestLinkChildrenBinary(&node, split); 327} 328 329TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) { 330 TreeNode node; 331 auto* split = 332 node.mutable_sparse_float_binary_split_default_right()->mutable_split(); 333 split->set_left_id(-1); 334 split->set_right_id(-1); 335 TestLinkChildrenBinary(&node, split); 336} 337 338TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) { 339 TreeNode node; 340 auto* split = node.mutable_categorical_id_binary_split(); 341 split->set_left_id(-1); 342 split->set_right_id(-1); 343 TestLinkChildrenBinary(&node, split); 344} 345 346TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) { 347 // Create unset node. 348 TreeNode node; 349 350 // Invalid case. 351 EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), 352 "A non-set node cannot have children."); 353} 354 355TEST_F(DecisionTreeTest, GetChildrenLeaf) { 356 TreeNode node; 357 node.mutable_leaf(); 358 TestGetChildren(node, {}); 359} 360 361TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) { 362 TreeNode node; 363 auto* split = node.mutable_dense_float_binary_split(); 364 split->set_left_id(23); 365 split->set_right_id(24); 366 TestGetChildren(node, {23, 24}); 367} 368 369TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) { 370 TreeNode node; 371 auto* split = 372 node.mutable_sparse_float_binary_split_default_left()->mutable_split(); 373 split->set_left_id(12); 374 split->set_right_id(13); 375 TestGetChildren(node, {12, 13}); 376} 377 378TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) { 379 TreeNode node; 380 auto* split = 381 node.mutable_sparse_float_binary_split_default_right()->mutable_split(); 382 split->set_left_id(1); 383 split->set_right_id(2); 384 TestGetChildren(node, {1, 2}); 385} 386 387TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) { 388 TreeNode node; 389 auto* split = node.mutable_categorical_id_binary_split(); 390 split->set_left_id(7); 391 split->set_right_id(8); 392 TestGetChildren(node, {7, 8}); 393} 394 395TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) { 396 TreeNode node; 397 TestGetChildren(node, {}); 398} 399 400} // namespace 401} // namespace trees 402} // namespace boosted_trees 403} // namespace tensorflow 404