1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
16#include "tensorflow/core/framework/tensor_testutil.h"
17#include "tensorflow/core/lib/core/status_test_util.h"
18#include "tensorflow/core/platform/test.h"
19
20namespace tensorflow {
21namespace boosted_trees {
22namespace learner {
23namespace {
24
25class ExamplePartitionerTest : public ::testing::Test {
26 protected:
27  ExamplePartitionerTest()
28      : thread_pool_(tensorflow::Env::Default(), "test_pool", 2),
29        batch_features_(2) {
30    dense_matrix_ = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
31    TF_EXPECT_OK(
32        batch_features_.Initialize({dense_matrix_}, {}, {}, {}, {}, {}, {}));
33  }
34
35  thread::ThreadPool thread_pool_;
36  Tensor dense_matrix_;
37  boosted_trees::utils::BatchFeatures batch_features_;
38};
39
40TEST_F(ExamplePartitionerTest, EmptyTree) {
41  boosted_trees::trees::DecisionTreeConfig tree_config;
42  std::vector<int32> example_partition_ids(2);
43  ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
44                                       &thread_pool_,
45                                       example_partition_ids.data());
46  EXPECT_EQ(0, example_partition_ids[0]);
47  EXPECT_EQ(0, example_partition_ids[1]);
48}
49
50TEST_F(ExamplePartitionerTest, UpdatePartitions) {
51  // Create tree with one split.
52  // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
53  boosted_trees::trees::DecisionTreeConfig tree_config;
54  auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
55  split->set_feature_column(0);
56  split->set_threshold(3.0f);
57  split->set_left_id(1);
58  split->set_right_id(2);
59  tree_config.add_nodes()->mutable_leaf();
60  tree_config.add_nodes()->mutable_leaf();
61
62  // Partition input:
63  // Instance 1 has !(7 <= 3) => go right => leaf 2.
64  // Instance 2 has (-2 <= 3) => go left => leaf 1.
65  std::vector<int32> example_partition_ids(2);
66  ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
67                                       &thread_pool_,
68                                       example_partition_ids.data());
69  EXPECT_EQ(2, example_partition_ids[0]);
70  EXPECT_EQ(1, example_partition_ids[1]);
71}
72
73TEST_F(ExamplePartitionerTest, PartitionExamples) {
74  // Create tree with one split.
75  // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
76  boosted_trees::trees::DecisionTreeConfig tree_config;
77  auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
78  split->set_feature_column(0);
79  split->set_threshold(3.0f);
80  split->set_left_id(1);
81  split->set_right_id(2);
82  tree_config.add_nodes()->mutable_leaf();
83  tree_config.add_nodes()->mutable_leaf();
84
85  // Partition input:
86  // Instance 1 has !(7 <= 3) => go right => leaf 2.
87  // Instance 2 has (-2 <= 3) => go left => leaf 1.
88  std::vector<int32> example_partition_ids(2);
89  ExamplePartitioner::PartitionExamples(tree_config, batch_features_, 1,
90                                        &thread_pool_,
91                                        example_partition_ids.data());
92  EXPECT_EQ(2, example_partition_ids[0]);
93  EXPECT_EQ(1, example_partition_ids[1]);
94}
95
96}  // namespace
97}  // namespace learner
98}  // namespace boosted_trees
99}  // namespace tensorflow
100