// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
15cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
16cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/framework/tensor_testutil.h"
17cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/lib/core/status_test_util.h"
18cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/platform/test.h"
19cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
20cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace tensorflow {
21cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace boosted_trees {
22cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace learner {
23cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace {
24cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
25cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengclass ExamplePartitionerTest : public ::testing::Test {
26cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng protected:
27cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  ExamplePartitionerTest()
28cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng      : thread_pool_(tensorflow::Env::Default(), "test_pool", 2),
29cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng        batch_features_(2) {
30cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng    dense_matrix_ = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
31cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng    TF_EXPECT_OK(
32cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng        batch_features_.Initialize({dense_matrix_}, {}, {}, {}, {}, {}, {}));
33cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  }
34cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
35cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  thread::ThreadPool thread_pool_;
36cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  Tensor dense_matrix_;
37cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  boosted_trees::utils::BatchFeatures batch_features_;
38cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng};
39cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
40cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, EmptyTree) {
41cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  boosted_trees::trees::DecisionTreeConfig tree_config;
42cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  std::vector<int32> example_partition_ids(2);
43cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
44cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                       &thread_pool_,
45cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                       example_partition_ids.data());
46cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(0, example_partition_ids[0]);
47cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(0, example_partition_ids[1]);
48cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}
49cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
50cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, UpdatePartitions) {
51cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Create tree with one split.
52cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
53cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  boosted_trees::trees::DecisionTreeConfig tree_config;
54cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
55cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_feature_column(0);
56cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_threshold(3.0f);
57cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_left_id(1);
58cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_right_id(2);
59cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  tree_config.add_nodes()->mutable_leaf();
60cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  tree_config.add_nodes()->mutable_leaf();
61cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
62cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Partition input:
63cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Instance 1 has !(7 <= 3) => go right => leaf 2.
64cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Instance 2 has (-2 <= 3) => go left => leaf 1.
65cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  std::vector<int32> example_partition_ids(2);
66cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
67cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                       &thread_pool_,
68cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                       example_partition_ids.data());
69cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(2, example_partition_ids[0]);
70cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(1, example_partition_ids[1]);
71cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}
72cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
73cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, PartitionExamples) {
74cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Create tree with one split.
75cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
76cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  boosted_trees::trees::DecisionTreeConfig tree_config;
77cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
78cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_feature_column(0);
79cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_threshold(3.0f);
80cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_left_id(1);
81cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  split->set_right_id(2);
82cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  tree_config.add_nodes()->mutable_leaf();
83cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  tree_config.add_nodes()->mutable_leaf();
84cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
85cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Partition input:
86cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Instance 1 has !(7 <= 3) => go right => leaf 2.
87cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  // Instance 2 has (-2 <= 3) => go left => leaf 1.
88cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  std::vector<int32> example_partition_ids(2);
89cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  ExamplePartitioner::PartitionExamples(tree_config, batch_features_, 1,
90cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                        &thread_pool_,
91cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng                                        example_partition_ids.data());
92cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(2, example_partition_ids[0]);
93cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng  EXPECT_EQ(1, example_partition_ids[1]);
94cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}
95cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng
96cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}  // namespace
97cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}  // namespace learner
98cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}  // namespace boosted_trees
99cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}  // namespace tensorflow
100