1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h" 16#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" 17 18namespace tensorflow { 19namespace boosted_trees { 20namespace learner { 21 22void ExamplePartitioner::UpdatePartitions( 23 const boosted_trees::trees::DecisionTreeConfig& tree_config, 24 const boosted_trees::utils::BatchFeatures& features, 25 const int desired_parallelism, thread::ThreadPool* const thread_pool, 26 int32* example_partition_ids) { 27 // Get batch size. 28 const int64 batch_size = features.batch_size(); 29 if (batch_size <= 0) { 30 return; 31 } 32 33 // Lambda for doing a block of work. 34 auto partition_examples_subset = [&tree_config, &features, 35 &example_partition_ids](const int64 start, 36 const int64 end) { 37 if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) { 38 auto examples_iterable = features.examples_iterable(start, end); 39 for (const auto& example : examples_iterable) { 40 int32& example_partition = example_partition_ids[example.example_idx]; 41 example_partition = boosted_trees::trees::DecisionTree::Traverse( 42 tree_config, example_partition, example); 43 DCHECK_GE(example_partition, 0); 44 } 45 } else { 46 std::fill(example_partition_ids + start, example_partition_ids + end, 0); 47 } 48 }; 49 50 // Parallelize partitioning over the batch. 51 boosted_trees::utils::ParallelFor(batch_size, desired_parallelism, 52 thread_pool, partition_examples_subset); 53} 54 55void ExamplePartitioner::PartitionExamples( 56 const boosted_trees::trees::DecisionTreeConfig& tree_config, 57 const boosted_trees::utils::BatchFeatures& features, 58 const int desired_parallelism, thread::ThreadPool* const thread_pool, 59 int32* example_partition_ids) { 60 // Get batch size. 61 const int64 batch_size = features.batch_size(); 62 if (batch_size <= 0) { 63 return; 64 } 65 66 // Lambda for doing a block of work. 67 auto partition_examples_subset = [&tree_config, &features, 68 &example_partition_ids](const int64 start, 69 const int64 end) { 70 if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) { 71 auto examples_iterable = features.examples_iterable(start, end); 72 for (const auto& example : examples_iterable) { 73 uint32 partition = boosted_trees::trees::DecisionTree::Traverse( 74 tree_config, 0, example); 75 example_partition_ids[example.example_idx] = partition; 76 DCHECK_GE(partition, 0); 77 } 78 } else { 79 std::fill(example_partition_ids + start, example_partition_ids + end, 0); 80 } 81 }; 82 83 // Parallelize partitioning over the batch. 84 boosted_trees::utils::ParallelFor(batch_size, desired_parallelism, 85 thread_pool, partition_examples_subset); 86} 87 88} // namespace learner 89} // namespace boosted_trees 90} // namespace tensorflow 91