1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
16#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
17
18namespace tensorflow {
19namespace boosted_trees {
20namespace learner {
21
22void ExamplePartitioner::UpdatePartitions(
23    const boosted_trees::trees::DecisionTreeConfig& tree_config,
24    const boosted_trees::utils::BatchFeatures& features,
25    const int desired_parallelism, thread::ThreadPool* const thread_pool,
26    int32* example_partition_ids) {
27  // Get batch size.
28  const int64 batch_size = features.batch_size();
29  if (batch_size <= 0) {
30    return;
31  }
32
33  // Lambda for doing a block of work.
34  auto partition_examples_subset = [&tree_config, &features,
35                                    &example_partition_ids](const int64 start,
36                                                            const int64 end) {
37    if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) {
38      auto examples_iterable = features.examples_iterable(start, end);
39      for (const auto& example : examples_iterable) {
40        int32& example_partition = example_partition_ids[example.example_idx];
41        example_partition = boosted_trees::trees::DecisionTree::Traverse(
42            tree_config, example_partition, example);
43        DCHECK_GE(example_partition, 0);
44      }
45    } else {
46      std::fill(example_partition_ids + start, example_partition_ids + end, 0);
47    }
48  };
49
50  // Parallelize partitioning over the batch.
51  boosted_trees::utils::ParallelFor(batch_size, desired_parallelism,
52                                    thread_pool, partition_examples_subset);
53}
54
55void ExamplePartitioner::PartitionExamples(
56    const boosted_trees::trees::DecisionTreeConfig& tree_config,
57    const boosted_trees::utils::BatchFeatures& features,
58    const int desired_parallelism, thread::ThreadPool* const thread_pool,
59    int32* example_partition_ids) {
60  // Get batch size.
61  const int64 batch_size = features.batch_size();
62  if (batch_size <= 0) {
63    return;
64  }
65
66  // Lambda for doing a block of work.
67  auto partition_examples_subset = [&tree_config, &features,
68                                    &example_partition_ids](const int64 start,
69                                                            const int64 end) {
70    if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) {
71      auto examples_iterable = features.examples_iterable(start, end);
72      for (const auto& example : examples_iterable) {
73        uint32 partition = boosted_trees::trees::DecisionTree::Traverse(
74            tree_config, 0, example);
75        example_partition_ids[example.example_idx] = partition;
76        DCHECK_GE(partition, 0);
77      }
78    } else {
79      std::fill(example_partition_ids + start, example_partition_ids + end, 0);
80    }
81  };
82
83  // Parallelize partitioning over the batch.
84  boosted_trees::utils::ParallelFor(batch_size, desired_parallelism,
85                                    thread_pool, partition_examples_subset);
86}
87
88}  // namespace learner
89}  // namespace boosted_trees
90}  // namespace tensorflow
91