// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng// ============================================================================= 15cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h" 16cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/framework/tensor_testutil.h" 17cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/lib/core/status_test_util.h" 18cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/platform/test.h" 19cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 20cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace tensorflow { 21cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace boosted_trees { 22cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace learner { 23cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengnamespace { 24cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 25cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengclass ExamplePartitionerTest : public ::testing::Test { 26cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng protected: 27cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng ExamplePartitionerTest() 28cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng : thread_pool_(tensorflow::Env::Default(), "test_pool", 2), 29cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng batch_features_(2) { 30cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng dense_matrix_ = test::AsTensor<float>({7.0f, -2.0f}, {2, 1}); 31cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng TF_EXPECT_OK( 32cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng batch_features_.Initialize({dense_matrix_}, {}, {}, {}, {}, {}, {})); 33cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng } 34cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 35cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng thread::ThreadPool thread_pool_; 36cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng Tensor dense_matrix_; 
37cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng boosted_trees::utils::BatchFeatures batch_features_; 38cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng}; 39cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 40cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, EmptyTree) { 41cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng boosted_trees::trees::DecisionTreeConfig tree_config; 42cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng std::vector<int32> example_partition_ids(2); 43cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1, 44cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng &thread_pool_, 45cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng example_partition_ids.data()); 46cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(0, example_partition_ids[0]); 47cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(0, example_partition_ids[1]); 48cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} 49cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 50cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, UpdatePartitions) { 51cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Create tree with one split. 52cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO. 
53cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng boosted_trees::trees::DecisionTreeConfig tree_config; 54cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split(); 55cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_feature_column(0); 56cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_threshold(3.0f); 57cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_left_id(1); 58cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_right_id(2); 59cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tree_config.add_nodes()->mutable_leaf(); 60cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tree_config.add_nodes()->mutable_leaf(); 61cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 62cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Partition input: 63cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Instance 1 has !(7 <= 3) => go right => leaf 2. 64cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Instance 2 has (-2 <= 3) => go left => leaf 1. 65cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng std::vector<int32> example_partition_ids(2); 66cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1, 67cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng &thread_pool_, 68cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng example_partition_ids.data()); 69cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(2, example_partition_ids[0]); 70cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(1, example_partition_ids[1]); 71cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} 72cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 73cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei FengTEST_F(ExamplePartitionerTest, PartitionExamples) { 74cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Create tree with one split. 
75cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO. 76cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng boosted_trees::trees::DecisionTreeConfig tree_config; 77cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split(); 78cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_feature_column(0); 79cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_threshold(3.0f); 80cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_left_id(1); 81cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng split->set_right_id(2); 82cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tree_config.add_nodes()->mutable_leaf(); 83cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tree_config.add_nodes()->mutable_leaf(); 84cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 85cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Partition input: 86cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Instance 1 has !(7 <= 3) => go right => leaf 2. 87cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng // Instance 2 has (-2 <= 3) => go left => leaf 1. 
88cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng std::vector<int32> example_partition_ids(2); 89cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng ExamplePartitioner::PartitionExamples(tree_config, batch_features_, 1, 90cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng &thread_pool_, 91cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng example_partition_ids.data()); 92cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(2, example_partition_ids[0]); 93cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng EXPECT_EQ(1, example_partition_ids[1]); 94cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} 95cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 96cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} // namespace 97cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} // namespace learner 98cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} // namespace boosted_trees 99cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng} // namespace tensorflow 100