// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using ::boosted_trees::QuantileConfig;
using boosted_trees::QuantileStreamResource;
using boosted_trees::utils::TensorUtils;

namespace {
const char* const kExampleWeightsName = "example_weights";
const char* const kMaxElementsName = "max_elements";
const char* const kNextStampTokenName = "next_stamp_token";
const char* const kStampTokenName = "stamp_token";
const char* const kAreBucketsReadyName = "are_buckets_ready";
const char* const kGenerateQuantiles = "generate_quantiles";
// Names for sparse arguments.
const char* const kNumSparseFeaturesName = "num_sparse_features";
const char* const kSparseBucketsName = "sparse_buckets";
const char* const kSparseValuesName = "sparse_values";
const char* const kSparseIndicesName = "sparse_indices";
const char* const kSparseSummariesName = "sparse_summaries";
const char* const kSparseConfigName = "sparse_config";
const char* const kSparseOutputTensorName = "sparse_quantiles";
// Names for dense arguments.
const char* const kDenseBucketsName = "dense_buckets";
const char* const kDenseConfigName = "dense_config";
const char* const kDenseOutputTensorName = "dense_quantiles";
const char* const kDenseSummariesName = "dense_summaries";
const char* const kDenseValuesName = "dense_values";
const char* const kNumDenseFeaturesName = "num_dense_features";
const char* const kResourceHandlesName = "quantile_accumulator_handles";
const char* const kNumQuantilesName = "num_quantiles";
const char* const kEpsilonName = "epsilon";
const char* const kBucketsName = "buckets";
const char* const kStreamStateName = "stream_state";
const char* const kSummariesName = "summaries";

using QuantileStream =
    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
using QuantileSummary =
    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
using QuantileSummaryEntry =
    boosted_trees::quantiles::WeightedQuantilesSummary<float,
                                                       float>::SummaryEntry;

std::vector<float> GetBuckets(const int32 feature,
                              const OpInputList& buckets_list) {
  const auto& buckets = buckets_list[feature].flat<float>();
  std::vector<float> buckets_vector(buckets.data(),
                                    buckets.data() + buckets.size());
  return buckets_vector;
}

int32 GetFeatureDimension(const int32 feature_index, const int64 instance,
                          const OpInputList* const indices_list) {
  if (indices_list != nullptr) {
    // Sparse multidimensional.
    return (*indices_list)[feature_index].matrix<int64>()(instance, 1);
  }
  // No indices, assume one-dimensional tensor.
  return 0;
}
// Allows quantization for each of the multiple dimensions of a sparse feature.
void QuantizeFeatures(
    const string& output_name, const OpInputList& values_list,
    const OpInputList& buckets_list,
    const OpInputList* const
        indices_list /* Optional; provided for sparse features. */,
    OpKernelContext* const context) {
  if (values_list.size() == 0) {
    return;
  }
  OpOutputList output_list;
  OP_REQUIRES_OK(context, context->output_list(output_name, &output_list));

  for (int32 feature_index = 0; feature_index < values_list.size();
       ++feature_index) {
    const Tensor& values_tensor = values_list[feature_index];
    const int64 num_values = values_tensor.dim_size(0);

    Tensor* output_t = nullptr;
    // Each output row holds the bucket id and the feature dimension for a
    // value.
    OP_REQUIRES_OK(
        context, output_list.allocate(feature_index,
                                      TensorShape({num_values, 2}), &output_t));

    auto output = output_t->matrix<int32>();

    const std::vector<float>& buckets_vector =
        GetBuckets(feature_index, buckets_list);
    auto flat_values = values_tensor.flat<float>();
    for (int64 instance = 0; instance < num_values; ++instance) {
      const float value = flat_values(instance);
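      // std::lower_bound finds the first boundary that is >= value; values
      // greater than every boundary are clamped into the last bucket below.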
      auto bucket_iter =
          std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
      if (bucket_iter == buckets_vector.end()) {
        --bucket_iter;
      }
      const int32 bucket =
          static_cast<int32>(bucket_iter - buckets_vector.begin());
      // Bucket id.
      output(instance, 0) = bucket;
      // Dimension.
      output(instance, 1) =
          GetFeatureDimension(feature_index, instance, indices_list);
    }
  }
}

// Validates attributes for the quantile ops.
Status ReadAndValidateAttributes(OpKernelConstruction* const context,
                                 int* num_dense_features,
                                 int* num_sparse_features) {
  TF_RETURN_IF_ERROR(
      context->GetAttr(kNumDenseFeaturesName, num_dense_features));
  TF_RETURN_IF_ERROR(
      context->GetAttr(kNumSparseFeaturesName, num_sparse_features));
  if ((*num_dense_features) + (*num_sparse_features) == 0) {
    return errors::InvalidArgument(
        "Please provide at least one sparse or dense feature.");
  }
  return Status::OK();
}

void ParseConfig(OpKernelConstruction* const context, const string& name,
                 std::vector<QuantileConfig>* output) {
  std::vector<string> serialized_config;
  OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config));
  output->reserve(serialized_config.size());
  QuantileConfig tmp;
  for (const auto& serialized_string : serialized_config) {
    OP_REQUIRES(context, tmp.ParseFromString(serialized_string),
                errors::InvalidArgument("Malformed QuantileConfig passed in."));
    output->push_back(tmp);
  }
}

// Generates de-duplicated boundaries on a finalized QuantileStream.
std::vector<float> GenerateBoundaries(const QuantileStream& stream,
                                      int num_boundaries) {
  std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);

  // De-duplicate boundaries, as the stream may produce repeated values.
  auto end_it = std::unique(boundaries.begin(), boundaries.end());
  boundaries.resize(std::distance(boundaries.begin(), end_it));
  return boundaries;
}

// Generates quantiles on a finalized QuantileStream.
std::vector<float> GenerateQuantiles(const QuantileStream& stream,
                                     int num_quantiles) {
  // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
  // will be returned.
  std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
  CHECK_EQ(boundaries.size(), num_quantiles + 1);
  return boundaries;
}

// Copies quantiles to output list.
void CopyBoundaries(OpKernelContext* const context,
                    const std::vector<float>& boundaries, const int64 index,
                    OpOutputList* output_list) {
  // Output to tensor.
  Tensor* output_t = nullptr;
  OP_REQUIRES_OK(
      context, output_list->allocate(
                   index, {static_cast<int64>(boundaries.size())}, &output_t));
  auto* quantiles_flat = output_t->flat<float>().data();
  memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size());
}

void CopySummaryToProto(const QuantileSummary& summary,
                        ::boosted_trees::QuantileSummaryState* summary_proto) {
  summary_proto->mutable_entries()->Reserve(summary.Size());
  for (const auto& entry : summary.GetEntryList()) {
    auto* new_entry = summary_proto->add_entries();
    new_entry->set_value(entry.value);
    new_entry->set_weight(entry.weight);
    new_entry->set_min_rank(entry.min_rank);
    new_entry->set_max_rank(entry.max_rank);
  }
}

}  // namespace

// Accumulator for Quantile Summaries.
REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource);

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<QuantileStreamResource>);

class CreateQuantileAccumulatorOp : public OpKernel {
 public:
  explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr(kNumQuantilesName, &num_quantiles_));
    OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
    OP_REQUIRES_OK(context,
                   context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
  }

  void Compute(OpKernelContext* context) override {
    // Create the resource only if it does not already exist; report the
    // status for all other errors. If one already exists, the new resource
    // is unreffed.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
                                             max_elements_, generate_quantiles_,
                                             stamp_token_t->scalar<int64>()());
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }

 private:
  float epsilon_;
  int32 num_quantiles_;
  // An upper bound on the number of entries that the summaries might have
  // for a feature.
  int64 max_elements_;
  bool generate_quantiles_;
};

REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
                        CreateQuantileAccumulatorOp);

// Adds a summary to the quantile summary stream.
class QuantileAccumulatorAddSummariesOp : public OpKernel {
 public:
  explicit QuantileAccumulatorAddSummariesOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
                                                &resource_handle_list));
    OpInputList summary_list;
    OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
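    // Shard the resource handles across the CPU worker threads; each shard
    // handles a contiguous [start, end) range of accumulator handles.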
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &summary_list, stamp_token](
            int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            ResourceHandle handle = resource_handle_list[resource_handle_idx]
                                        .flat<ResourceHandle>()(0);
            QuantileStreamResource* streams_resource;
            // Create a reference to the underlying resource using the handle.
            OP_REQUIRES_OK(context,
                           LookupResource(context, handle, &streams_resource));
            // Remove the reference at the end of this scope.
            mutex_lock l(*streams_resource->mutex());
            core::ScopedUnref unref_me(streams_resource);

            // If the stamp is invalid we drop the update.
            if (!streams_resource->is_stamp_valid(stamp_token)) {
              VLOG(1)
                  << "Invalid stamp token in QuantileAccumulatorAddSummariesOp."
                  << " Passed stamp token: " << stamp_token << " "
                  << "Current token: " << streams_resource->stamp();
              return;
            }

            protobuf::Arena arena;
            ::boosted_trees::QuantileSummaryState* summary_proto =
                protobuf::Arena::CreateMessage<
                    ::boosted_trees::QuantileSummaryState>(&arena);
            OP_REQUIRES(
                context,
                ParseProtoUnlimited(
                    summary_proto,
                    summary_list[resource_handle_idx].scalar<string>()()),
                errors::InvalidArgument("Unable to parse quantile summary."));
            std::vector<QuantileSummaryEntry> entries;
            entries.reserve(summary_proto->entries_size());
            for (const auto& entry : summary_proto->entries()) {
              entries.emplace_back(entry.value(), entry.weight(),
                                   entry.min_rank(), entry.max_rank());
            }

            // Add the summary to the quantile stream.
            streams_resource->stream(stamp_token)->PushSummary(entries);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU),
    QuantileAccumulatorAddSummariesOp);

// Generates quantile summaries for the given set of float features, using the
// epsilon attribute.
class MakeQuantileSummariesOp : public OpKernel {
 public:
  explicit MakeQuantileSummariesOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features_,
                                             &num_sparse_features_));
    OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
  }

  void Compute(OpKernelContext* const context) override {
    // Read dense float features list.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
                                context, &dense_float_features_list));

    // Read sparse float features list.
    OpInputList sparse_float_feature_indices_list;
    OpInputList sparse_float_feature_values_list;
    OpInputList sparse_float_feature_shapes_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
                                context, &sparse_float_feature_indices_list,
                                &sparse_float_feature_values_list,
                                &sparse_float_feature_shapes_list));

    // Parse example weights and get batch size.
    const Tensor* example_weights_t;
    OP_REQUIRES_OK(context,
                   context->input(kExampleWeightsName, &example_weights_t));
    auto example_weights = example_weights_t->flat<float>();
    const int64 batch_size = example_weights.size();

    OpOutputList sparse_summaries_output_list;
    OP_REQUIRES_OK(context,
                   context->output_list(kSparseSummariesName,
                                        &sparse_summaries_output_list));
    OpOutputList dense_summaries_output_list;
    OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName,
                                                 &dense_summaries_output_list));

    auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
      auto copy_over_summaries = [&](const QuantileStream& stream,
                                     const int64 index,
                                     OpOutputList* output_list) {
        protobuf::Arena arena;
        ::boosted_trees::QuantileSummaryState* summary_proto =
            protobuf::Arena::CreateMessage<
                ::boosted_trees::QuantileSummaryState>(&arena);
        const auto& summary = stream.GetFinalSummary();
        CopySummaryToProto(summary, summary_proto);
        // Output to tensor.
        Tensor* output_t = nullptr;
        OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
        summary_proto->SerializeToString(&output_t->scalar<string>()());
      };

      // Each shard processes a contiguous block of feature indices. Dense
      // features come first, then sparse features, i.e. the block is a
      // sub-range of [0, num_dense_features_ + num_sparse_features_).
      for (int64 i = begin; i < end; ++i) {
        if (i < num_dense_features_) {
          const int64 dense_index = i;
          const auto dense_values =
              dense_float_features_list[dense_index].flat<float>();
          QuantileStream stream(epsilon_, batch_size + 1);
          // Run quantile summary generation.
          for (int64 j = 0; j < batch_size; ++j) {
            stream.PushEntry(dense_values(j), example_weights(j));
          }
          stream.Finalize();
          // Copy summaries to output.
          copy_over_summaries(stream, dense_index,
                              &dense_summaries_output_list);
        } else {
          const int64 sparse_index = i - num_dense_features_;
          const auto sparse_values =
              sparse_float_feature_values_list[sparse_index].flat<float>();
          const auto sparse_indices =
              sparse_float_feature_indices_list[sparse_index].matrix<int64>();
          const auto dense_shape =
              sparse_float_feature_shapes_list[sparse_index].flat<int64>();
          OP_REQUIRES(context, batch_size == dense_shape(0),
                      errors::InvalidArgument(
                          "Sparse column shape doesn't match the batch size."));
          QuantileStream stream(epsilon_, batch_size + 1);
          // Run quantile summary generation.
          const int64 num_sparse_rows =
              sparse_float_feature_indices_list[sparse_index].dim_size(0);
          for (int64 j = 0; j < num_sparse_rows; ++j) {
            const int64 example_id = sparse_indices(j, 0);
            stream.PushEntry(sparse_values(j), example_weights(example_id));
          }
          stream.Finalize();
          // Copy summaries to output.
          copy_over_summaries(stream, sparse_index,
                              &sparse_summaries_output_list);
        }
      }
    };
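    // Rough per-feature cost estimate, used by Shard to decide how to split
    // the features across the available worker threads.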
    const int64 kCostPerUnit = 500 * batch_size;
    const int64 num_features = num_sparse_features_ + num_dense_features_;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features,
          kCostPerUnit, do_quantile_summary_gen);
  }

 private:
  int num_dense_features_;
  int num_sparse_features_;
  float epsilon_;
};

REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU),
                        MakeQuantileSummariesOp);

// Serializes the state of streams.
class QuantileAccumulatorSerializeOp : public OpKernel {
 public:
  explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    int64 stamp_token = streams_resource->stamp();
    Tensor* stream_state_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output(kStreamStateName, TensorShape({}),
                                            &stream_state_t));
    bool are_buckets_ready = streams_resource->are_buckets_ready();

    // Serialize the internal summaries of the stream and, if the buckets are
    // ready, the bucket boundaries.
    const QuantileStream& stream = *streams_resource->stream(stamp_token);
    const std::vector<float>& boundaries =
        are_buckets_ready ? streams_resource->boundaries(stamp_token)
                          : std::vector<float>();
    protobuf::Arena arena;
    ::boosted_trees::QuantileStreamState* stream_proto =
        protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>(
            &arena);
    for (const auto& summary : stream.SerializeInternalSummaries()) {
      CopySummaryToProto(summary, stream_proto->add_summaries());
    }
    stream_proto->SerializeToString(&stream_state_t->scalar<string>()());
    Tensor* buckets_t = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t));
    auto* quantiles_flat = buckets_t->flat<float>().data();
    memcpy(quantiles_flat, boundaries.data(),
           sizeof(float) * boundaries.size());
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(kStampTokenName, TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = stamp_token;
    Tensor* are_buckets_ready_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {},
                                                     &are_buckets_ready_t));
    are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
  }
};

REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU),
                        QuantileAccumulatorSerializeOp);

// Deserializes the state of streams.
class QuantileAccumulatorDeserializeOp : public OpKernel {
 public:
  explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    int64 old_stamp_token = streams_resource->stamp();
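    // The stream is restored under the resource's current (old) stamp token;
    // the new stamp token from the input is applied at the end of Compute.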

    const Tensor* stream_state_t;
    OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t));
    const Tensor* buckets_t;
    OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t));

    QuantileStream* stream = streams_resource->stream(old_stamp_token);
    ::boosted_trees::QuantileStreamState state_proto;
    OP_REQUIRES(
        context,
        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()),
        errors::InvalidArgument("Unable to parse quantile stream state."));
    std::vector<QuantileSummary> summaries;
    summaries.reserve(state_proto.summaries_size());
    std::vector<QuantileSummaryEntry> entries;
    for (const auto& summary : state_proto.summaries()) {
      entries.clear();
      entries.reserve(summary.entries_size());
      for (const auto& entry : summary.entries()) {
        entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(),
                             entry.max_rank());
      }
      summaries.emplace_back();
      summaries[summaries.size() - 1].BuildFromSummaryEntries(entries);
    }
    stream->DeserializeInternalSummaries(summaries);

    const auto& buckets = buckets_t->vec<float>();
    std::vector<float> result;
    result.reserve(buckets.size());

    for (size_t i = 0; i < buckets.size(); ++i) {
      result.push_back(buckets(i));
    }
    streams_resource->set_boundaries(old_stamp_token, result);

    // Reset the stamp token.
    const Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    streams_resource->set_stamp(stamp_token);

    const Tensor* are_buckets_ready_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->input(kAreBucketsReadyName, &are_buckets_ready_t));
    streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU),
    QuantileAccumulatorDeserializeOp);

// Flushes the quantile summary stream resource.
class QuantileAccumulatorFlushOp : public OpKernel {
 public:
  explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    CHECK(streams_resource->is_stamp_valid(stamp_token))
        << "Invalid stamp token in QuantileAccumulatorFlushOp. "
        << "Passed stamp token: " << stamp_token << " "
        << "Current token: " << streams_resource->stamp();
    QuantileStream* stream = streams_resource->stream(stamp_token);
    bool generate_quantiles = streams_resource->generate_quantiles();
    stream->Finalize();

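    // If generate_quantiles is set, emit exactly num_quantiles + 1 values
    // without de-duplication; otherwise emit de-duplicated boundary values.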
    streams_resource->set_boundaries(
        stamp_token,
        generate_quantiles
            ? GenerateQuantiles(*stream, streams_resource->num_quantiles())
            : GenerateBoundaries(*stream, streams_resource->num_quantiles()));

    streams_resource->Reset(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU),
                        QuantileAccumulatorFlushOp);

// Flushes the quantile summary stream resource. This version computes the
// summary.
class QuantileAccumulatorFlushSummaryOp : public OpKernel {
 public:
  explicit QuantileAccumulatorFlushSummaryOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    CHECK(streams_resource->is_stamp_valid(stamp_token))
        << "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. "
        << "Passed stamp token: " << stamp_token << " "
        << "Current token: " << streams_resource->stamp();
    QuantileStream* stream = streams_resource->stream(stamp_token);
    stream->Finalize();
    protobuf::Arena arena;
    ::boosted_trees::QuantileSummaryState* summary_proto =
        protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>(
            &arena);
    const auto& summary = stream->GetFinalSummary();
    CopySummaryToProto(summary, summary_proto);
    // Output to tensor.
    Tensor* output_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output_t));
    summary_proto->SerializeToString(&output_t->scalar<string>()());
    streams_resource->Reset(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU),
    QuantileAccumulatorFlushSummaryOp);

// Get bucket boundaries from summaries.
class QuantileAccumulatorGetBucketsOp : public OpKernel {
 public:
  explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
                                                &resource_handle_list));
    OpOutputList are_buckets_ready_list;
    OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName,
                                                 &are_buckets_ready_list));
    OpOutputList buckets_list;
    OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &are_buckets_ready_list,
         &buckets_list, stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            ResourceHandle handle = resource_handle_list[resource_handle_idx]
                                        .flat<ResourceHandle>()(0);
            QuantileStreamResource* streams_resource;
            OP_REQUIRES_OK(context,
                           LookupResource(context, handle, &streams_resource));
            // Remove the reference at the end of this scope.
            mutex_lock l(*streams_resource->mutex());
            core::ScopedUnref unref_me(streams_resource);

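            // Buckets are only reported as ready when the stamp token is
            // still valid; otherwise an empty bucket list is returned.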
            bool are_buckets_ready =
                streams_resource->is_stamp_valid(stamp_token) &&
                streams_resource->are_buckets_ready();

            Tensor* are_buckets_ready_t = nullptr;
            OP_REQUIRES_OK(context,
                           are_buckets_ready_list.allocate(
                               resource_handle_idx, {}, &are_buckets_ready_t));
            are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;

            const std::vector<float>& boundaries =
                are_buckets_ready ? streams_resource->boundaries(stamp_token)
                                  : std::vector<float>();
            Tensor* output_t = nullptr;
            OP_REQUIRES_OK(context, buckets_list.allocate(
                                        resource_handle_idx,
                                        {static_cast<int64>(boundaries.size())},
                                        &output_t));
            auto* quantiles_flat = output_t->flat<float>().data();
            memcpy(quantiles_flat, boundaries.data(),
                   sizeof(float) * boundaries.size());
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU),
    QuantileAccumulatorGetBucketsOp);

// Generates bucket boundaries for the given float features and configs.
class QuantileBucketsOp : public OpKernel {
 public:
  explicit QuantileBucketsOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features_,
                                             &num_sparse_features_));

    ParseConfig(context, kDenseConfigName, &dense_configs_);
    OP_REQUIRES(context, dense_configs_.size() == num_dense_features_,
                errors::InvalidArgument(
                    "Mismatch in number of dense quantile configs."));
    ParseConfig(context, kSparseConfigName, &sparse_configs_);
    OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_,
                errors::InvalidArgument(
                    "Mismatch in number of sparse quantile configs."));
  }

  void Compute(OpKernelContext* const context) override {
    // Read dense float features list.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
                                context, &dense_float_features_list));

    // Read sparse float features list.
    OpInputList sparse_float_feature_indices_list;
    OpInputList sparse_float_feature_values_list;
    OpInputList sparse_float_feature_shapes_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
                                context, &sparse_float_feature_indices_list,
                                &sparse_float_feature_values_list,
                                &sparse_float_feature_shapes_list));

    // Parse example weights and get batch size.
    const Tensor* example_weights_t;
    OP_REQUIRES_OK(context,
                   context->input(kExampleWeightsName, &example_weights_t));
    auto example_weights = example_weights_t->flat<float>();
    const int64 batch_size = example_weights.size();

    OpOutputList sparse_buckets_output_list;
    OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName,
                                                 &sparse_buckets_output_list));
    OpOutputList dense_buckets_output_list;
    OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName,
                                                 &dense_buckets_output_list));

    auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) {
      // Each shard processes a contiguous block of feature indices. Sparse
      // features come first, then dense features, i.e. the block is a
      // sub-range of [0, sparse_configs_.size() + dense_configs_.size()).
      for (int64 i = begin; i < end; ++i) {
        if (i < sparse_configs_.size()) {
          const int64 sparse_index = i;
          const auto sparse_values =
              sparse_float_feature_values_list[sparse_index].flat<float>();
          const auto sparse_indices =
              sparse_float_feature_indices_list[sparse_index].matrix<int64>();
          QuantileStream stream(sparse_configs_[sparse_index].eps(),
                                batch_size);
          // Run quantile summary generation.
          const int64 num_sparse_rows =
              sparse_float_feature_indices_list[sparse_index].dim_size(0);
          for (int64 j = 0; j < num_sparse_rows; ++j) {
            const int64 example_id = sparse_indices(j, 0);
            stream.PushEntry(sparse_values(j), example_weights(example_id));
          }
          stream.Finalize();
          // Create buckets.
          const auto boundaries = GenerateBoundaries(
              stream, sparse_configs_[sparse_index].num_quantiles());
          CopyBoundaries(context, boundaries, sparse_index,
                         &sparse_buckets_output_list);

        } else {
          const int64 dense_index = i - sparse_configs_.size();
          const auto dense_values =
              dense_float_features_list[dense_index].flat<float>();
          QuantileStream stream(dense_configs_[dense_index].eps(), batch_size);
          // Run quantile summary generation.
          for (int64 j = 0; j < batch_size; ++j) {
            stream.PushEntry(dense_values(j), example_weights(j));
          }
          stream.Finalize();
          // Create buckets.
          const auto boundaries = GenerateBoundaries(
              stream, dense_configs_[dense_index].num_quantiles());
          CopyBoundaries(context, boundaries, dense_index,
                         &dense_buckets_output_list);
        }
      }
    };

    const int64 kCostPerUnit = 500 * batch_size;
    const int64 num_features = sparse_configs_.size() + dense_configs_.size();
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features,
          kCostPerUnit, do_quantile_bucket_gen);
  }

 private:
  int num_dense_features_;
  int num_sparse_features_;
  std::vector<QuantileConfig> dense_configs_;
  std::vector<QuantileConfig> sparse_configs_;
};

REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU),
                        QuantileBucketsOp);

// Given the calculated quantile thresholds and input data, this operation
// converts the input features into the buckets (categorical values), depending
// on which quantile they fall into.
class QuantilesOp : public OpKernel {
 public:
  explicit QuantilesOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    int num_dense_features;
    int num_sparse_features;
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features,
                                             &num_sparse_features));
  }

  void Compute(OpKernelContext* const context) override {
    // Dense features inputs.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, context->input_list(kDenseValuesName,
                                                &dense_float_features_list));
    OpInputList dense_buckets_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kDenseBucketsName, &dense_buckets_list));

    if (dense_buckets_list.size() > 0) {
      // Check the first tensor to make sure it is the right shape.
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()),
          errors::InvalidArgument(
              strings::Printf("Dense buckets should be flat vectors")));
    }

    // Sparse features inputs.
    OpInputList sparse_float_feature_values_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kSparseValuesName,
                                       &sparse_float_feature_values_list));

    OpInputList sparse_float_indices_list;
    OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName,
                                                &sparse_float_indices_list));

    OpInputList sparse_buckets_list;
    OP_REQUIRES_OK(
        context, context->input_list(kSparseBucketsName, &sparse_buckets_list));

    if (sparse_buckets_list.size() > 0) {
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(
              sparse_buckets_list[0].shape()),
          errors::InvalidArgument("Sparse buckets should be flat vectors"));
    }

    // Quantize the feature values.
    QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list,
                     dense_buckets_list, nullptr, context);

    QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list,
                     sparse_buckets_list, &sparse_float_indices_list, context);
  }
};

REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp);

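// Bucketizes each input value against a provided sorted vector of boundaries.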
template <typename T>
class BucketizeWithInputBoundariesOp : public OpKernel {
 public:
  explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& boundaries_tensor = context->input(1);
    VLOG(1) << "boundaries has shape: "
            << boundaries_tensor.shape().DebugString();
    auto boundaries = boundaries_tensor.flat<float>();
    std::vector<T> boundaries_vector;
    boundaries_vector.reserve(boundaries.size());
    for (size_t i = 0; i < boundaries.size(); i++) {
      boundaries_vector.push_back(boundaries(i));
      VLOG(1) << "boundaries(" << i << ") : " << boundaries(i);
    }
    OP_REQUIRES(
        context,
        std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()),
        errors::InvalidArgument("Expected sorted boundaries"));

    const Tensor& input_tensor = context->input(0);
    VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString()
            << " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype());
    auto input = input_tensor.flat<T>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->template flat<int32>();

    for (size_t i = 0; i < input.size(); i++) {
      output(i) = CalculateBucketIndex(input(i), boundaries_vector);
    }
  }

 private:
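  // Returns the number of boundaries that are <= value: std::upper_bound
  // finds the first boundary strictly greater than value, so values below
  // all boundaries map to 0 and values above all boundaries map to
  // boundaries_vector.size().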
  int32 CalculateBucketIndex(const T value,
                             const std::vector<T>& boundaries_vector) {
    auto first_bigger_it = std::upper_bound(boundaries_vector.begin(),
                                            boundaries_vector.end(), value);
    int32 index = first_bigger_it - boundaries_vector.begin();
    CHECK(index >= 0 && index <= boundaries_vector.size())
        << "Invalid bucket index: " << index
        << " boundaries_vector.size(): " << boundaries_vector.size();
    return index;
  }
};

#define REGISTER_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T"),         \
                          BucketizeWithInputBoundariesOp<T>);

REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

}  // namespace tensorflow