1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include <algorithm> 16#include <iterator> 17#include <string> 18#include <vector> 19 20#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h" 21#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" 22#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" 23#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" 24#include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h" 25#include "tensorflow/core/framework/op_kernel.h" 26#include "tensorflow/core/framework/resource_mgr.h" 27#include "tensorflow/core/framework/tensor.h" 28#include "tensorflow/core/framework/tensor_shape.h" 29#include "tensorflow/core/framework/types.h" 30#include "tensorflow/core/lib/core/errors.h" 31#include "tensorflow/core/lib/core/status.h" 32#include "tensorflow/core/lib/strings/stringprintf.h" 33#include "tensorflow/core/platform/types.h" 34#include "tensorflow/core/util/work_sharder.h" 35 36namespace tensorflow { 37 38using ::boosted_trees::QuantileConfig; 39using boosted_trees::QuantileStreamResource; 40using boosted_trees::utils::TensorUtils; 41 42namespace { 43const char* const kExampleWeightsName = "example_weights"; 44const char* const kMaxElementsName = "max_elements"; 45const char* const kNextStampTokenName = "next_stamp_token"; 46const char* const kStampTokenName = "stamp_token"; 47const char* const kAreBucketsReadyName = "are_buckets_ready"; 48const char* const kGenerateQuantiles = "generate_quantiles"; 49// Names for sparse arguments. 50const char* const kNumSparseFeaturesName = "num_sparse_features"; 51const char* const kSparseBucketsName = "sparse_buckets"; 52const char* const kSparseValuesName = "sparse_values"; 53const char* const kSparseIndicesName = "sparse_indices"; 54const char* const kSparseSummariesName = "sparse_summaries"; 55const char* const kSparseConfigName = "sparse_config"; 56const char* const kSparseOutputTensorName = "sparse_quantiles"; 57// Names for dense arguments. 58const char* const kDenseBucketsName = "dense_buckets"; 59const char* const kDenseConfigName = "dense_config"; 60const char* const kDenseOutputTensorName = "dense_quantiles"; 61const char* const kDenseSummariesName = "dense_summaries"; 62const char* const kDenseValuesName = "dense_values"; 63const char* const kNumDenseFeaturesName = "num_dense_features"; 64const char* const kResourceHandlesName = "quantile_accumulator_handles"; 65const char* const kNumQuantilesName = "num_quantiles"; 66const char* const kEpsilonName = "epsilon"; 67const char* const kBucketsName = "buckets"; 68const char* const kStreamStateName = "stream_state"; 69const char* const kSummariesName = "summaries"; 70 71using QuantileStream = 72 boosted_trees::quantiles::WeightedQuantilesStream<float, float>; 73using QuantileSummary = 74 boosted_trees::quantiles::WeightedQuantilesSummary<float, float>; 75using QuantileSummaryEntry = 76 boosted_trees::quantiles::WeightedQuantilesSummary<float, 77 float>::SummaryEntry; 78 79std::vector<float> GetBuckets(const int32 feature, 80 const OpInputList& buckets_list) { 81 const auto& buckets = buckets_list[feature].flat<float>(); 82 std::vector<float> buckets_vector(buckets.data(), 83 buckets.data() + buckets.size()); 84 return buckets_vector; 85} 86 87int32 GetFeatureDimension(const int32 feature_index, const int64 instance, 88 const OpInputList* const indices_list) { 89 if (indices_list != nullptr) { 90 // Sparse multidimensional. 91 return (*indices_list)[feature_index].matrix<int64>()(instance, 1); 92 } 93 // No indices, assume one-dimensional tensor. 94 return 0; 95} 96 97// Allows quantization for each of multiple dimensions of a sparse feature. 98void QuantizeFeatures( 99 const string& output_name, const OpInputList& values_list, 100 const OpInputList& buckets_list, 101 const OpInputList* const 102 indices_list /** Optional, provide for sparse features **/, 103 OpKernelContext* const context) { 104 if (values_list.size() == 0) { 105 return; 106 } 107 OpOutputList output_list; 108 OP_REQUIRES_OK(context, context->output_list(output_name, &output_list)); 109 110 for (int32 feature_index = 0; feature_index < values_list.size(); 111 ++feature_index) { 112 const Tensor& values_tensor = values_list[feature_index]; 113 const int64 num_values = values_tensor.dim_size(0); 114 115 Tensor* output_t = nullptr; 116 // Output will have bucket id and dimension of the features for that bucket. 117 OP_REQUIRES_OK( 118 context, output_list.allocate(feature_index, 119 TensorShape({num_values, 2}), &output_t)); 120 121 auto output = output_t->matrix<int32>(); 122 123 const std::vector<float>& buckets_vector = 124 GetBuckets(feature_index, buckets_list); 125 auto flat_values = values_tensor.flat<float>(); 126 for (int64 instance = 0; instance < num_values; ++instance) { 127 const float value = flat_values(instance); 128 auto bucket_iter = 129 std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value); 130 if (bucket_iter == buckets_vector.end()) { 131 --bucket_iter; 132 } 133 const int32 bucket = 134 static_cast<int32>(bucket_iter - buckets_vector.begin()); 135 // Bucket id. 136 output(instance, 0) = bucket; 137 // Dimension. 138 output(instance, 1) = 139 GetFeatureDimension(feature_index, instance, indices_list); 140 } 141 } 142} 143 144// Validates attributes for the quantile ops. 145Status ReadAndValidateAttributes(OpKernelConstruction* const context, 146 int* num_dense_features, 147 int* num_sparse_features) { 148 TF_RETURN_IF_ERROR( 149 context->GetAttr(kNumDenseFeaturesName, num_dense_features)); 150 TF_RETURN_IF_ERROR( 151 context->GetAttr(kNumSparseFeaturesName, num_sparse_features)); 152 if ((*num_dense_features) + (*num_sparse_features) == 0) { 153 return errors::InvalidArgument( 154 "Please provide at least sparse or dense features."); 155 } 156 return Status::OK(); 157} 158 159void ParseConfig(OpKernelConstruction* const context, const string& name, 160 std::vector<QuantileConfig>* output) { 161 std::vector<string> serialized_config; 162 OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config)); 163 output->reserve(serialized_config.size()); 164 QuantileConfig tmp; 165 for (const auto& serialized_string : serialized_config) { 166 OP_REQUIRES(context, tmp.ParseFromString(serialized_string), 167 errors::InvalidArgument("Malformed QuantileConfig passed in.")); 168 output->push_back(tmp); 169 } 170} 171 172// Generates quantiles on a finalized QuantileStream. 173std::vector<float> GenerateBoundaries(const QuantileStream& stream, 174 int num_boundaries) { 175 std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries); 176 177 // Uniquify elements as we may get dupes. 178 auto end_it = std::unique(boundaries.begin(), boundaries.end()); 179 boundaries.resize(std::distance(boundaries.begin(), end_it)); 180 return boundaries; 181} 182 183// Generates quantiles on a finalized QuantileStream. 184std::vector<float> GenerateQuantiles(const QuantileStream& stream, 185 int num_quantiles) { 186 // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values 187 // will be returned. 188 std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles); 189 CHECK_EQ(boundaries.size(), num_quantiles + 1); 190 return boundaries; 191} 192 193// Copies quantiles to output list. 194void CopyBoundaries(OpKernelContext* const context, 195 const std::vector<float>& boundaries, const int64 index, 196 OpOutputList* output_list) { 197 // Output to tensor. 198 Tensor* output_t = nullptr; 199 OP_REQUIRES_OK( 200 context, output_list->allocate( 201 index, {static_cast<int64>(boundaries.size())}, &output_t)); 202 auto* quantiles_flat = output_t->flat<float>().data(); 203 memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size()); 204} 205 206void CopySummaryToProto(const QuantileSummary& summary, 207 ::boosted_trees::QuantileSummaryState* summary_proto) { 208 summary_proto->mutable_entries()->Reserve(summary.Size()); 209 for (const auto& entry : summary.GetEntryList()) { 210 auto* new_entry = summary_proto->add_entries(); 211 new_entry->set_value(entry.value); 212 new_entry->set_weight(entry.weight); 213 new_entry->set_min_rank(entry.min_rank); 214 new_entry->set_max_rank(entry.max_rank); 215 } 216} 217 218} // namespace 219 220// Accumulator for Quantile Summaries. 221REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource); 222 223REGISTER_KERNEL_BUILDER( 224 Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU), 225 IsResourceInitialized<QuantileStreamResource>); 226 227class CreateQuantileAccumulatorOp : public OpKernel { 228 public: 229 explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context) 230 : OpKernel(context) { 231 OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_)); 232 OP_REQUIRES_OK(context, 233 context->GetAttr(kNumQuantilesName, &num_quantiles_)); 234 OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_)); 235 OP_REQUIRES_OK(context, 236 context->GetAttr(kGenerateQuantiles, &generate_quantiles_)); 237 } 238 239 void Compute(OpKernelContext* context) override { 240 // Only create one, if one does not exist already. Report status for all 241 // other exceptions. If one already exists, it unrefs the new one. 242 const Tensor* stamp_token_t; 243 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 244 auto result = new QuantileStreamResource(epsilon_, num_quantiles_, 245 max_elements_, generate_quantiles_, 246 stamp_token_t->scalar<int64>()()); 247 auto status = CreateResource(context, HandleFromInput(context, 0), result); 248 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 249 OP_REQUIRES(context, false, status); 250 } 251 } 252 253 private: 254 float epsilon_; 255 int32 num_quantiles_; 256 // An upperbound on the number of enteries that the summaries might have 257 // for a feature. 258 int64 max_elements_; 259 bool generate_quantiles_; 260}; 261 262REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU), 263 CreateQuantileAccumulatorOp); 264 265// Adds a summary to the quantile summary stream. 266class QuantileAccumulatorAddSummariesOp : public OpKernel { 267 public: 268 explicit QuantileAccumulatorAddSummariesOp( 269 OpKernelConstruction* const context) 270 : OpKernel(context) {} 271 272 void Compute(OpKernelContext* context) override { 273 OpInputList resource_handle_list; 274 OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName, 275 &resource_handle_list)); 276 OpInputList summary_list; 277 OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list)); 278 279 const Tensor* stamp_token_t; 280 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 281 int64 stamp_token = stamp_token_t->scalar<int64>()(); 282 283 thread::ThreadPool* const worker_threads = 284 context->device()->tensorflow_cpu_worker_threads()->workers; 285 boosted_trees::utils::ParallelFor( 286 resource_handle_list.size(), worker_threads->NumThreads(), 287 worker_threads, 288 [&context, &resource_handle_list, &summary_list, stamp_token]( 289 int64 start, int64 end) { 290 for (int resource_handle_idx = start; resource_handle_idx < end; 291 ++resource_handle_idx) { 292 ResourceHandle handle = resource_handle_list[resource_handle_idx] 293 .flat<ResourceHandle>()(0); 294 QuantileStreamResource* streams_resource; 295 // Create a reference to the underlying resource using the handle. 296 OP_REQUIRES_OK(context, 297 LookupResource(context, handle, &streams_resource)); 298 // Remove the reference at the end of this scope. 299 mutex_lock l(*streams_resource->mutex()); 300 core::ScopedUnref unref_me(streams_resource); 301 302 // If the stamp is invalid we drop the update. 303 if (!streams_resource->is_stamp_valid(stamp_token)) { 304 VLOG(1) 305 << "Invalid stamp token in QuantileAccumulatorAddSummariesOp." 306 << " Passed stamp token: " << stamp_token << " " 307 << "Current token: " << streams_resource->stamp(); 308 return; 309 } 310 311 protobuf::Arena arena; 312 ::boosted_trees::QuantileSummaryState* summary_proto = 313 protobuf::Arena::CreateMessage< 314 ::boosted_trees::QuantileSummaryState>(&arena); 315 OP_REQUIRES( 316 context, 317 ParseProtoUnlimited( 318 summary_proto, 319 summary_list[resource_handle_idx].scalar<string>()()), 320 errors::InvalidArgument("Unable to parse quantile summary.")); 321 std::vector<QuantileSummaryEntry> entries; 322 entries.reserve(summary_proto->entries_size()); 323 for (const auto& entry : summary_proto->entries()) { 324 entries.emplace_back(entry.value(), entry.weight(), 325 entry.min_rank(), entry.max_rank()); 326 } 327 328 // Add the summary to the quantile stream. 329 streams_resource->stream(stamp_token)->PushSummary(entries); 330 } 331 }); 332 } 333}; 334 335REGISTER_KERNEL_BUILDER( 336 Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU), 337 QuantileAccumulatorAddSummariesOp); 338 339// Generates summaries for given set of float values, and the given config. 340class MakeQuantileSummariesOp : public OpKernel { 341 public: 342 explicit MakeQuantileSummariesOp(OpKernelConstruction* const context) 343 : OpKernel(context) { 344 OP_REQUIRES_OK(context, 345 ReadAndValidateAttributes(context, &num_dense_features_, 346 &num_sparse_features_)); 347 OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_)); 348 } 349 350 void Compute(OpKernelContext* const context) override { 351 // Read dense float features list; 352 OpInputList dense_float_features_list; 353 OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( 354 context, &dense_float_features_list)); 355 356 // Read sparse float features list; 357 OpInputList sparse_float_feature_indices_list; 358 OpInputList sparse_float_feature_values_list; 359 OpInputList sparse_float_feature_shapes_list; 360 OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures( 361 context, &sparse_float_feature_indices_list, 362 &sparse_float_feature_values_list, 363 &sparse_float_feature_shapes_list)); 364 365 // Parse example weights and get batch size. 366 const Tensor* example_weights_t; 367 OP_REQUIRES_OK(context, 368 context->input(kExampleWeightsName, &example_weights_t)); 369 auto example_weights = example_weights_t->flat<float>(); 370 const int64 batch_size = example_weights.size(); 371 372 OpOutputList sparse_summaries_output_list; 373 OP_REQUIRES_OK(context, 374 context->output_list(kSparseSummariesName, 375 &sparse_summaries_output_list)); 376 OpOutputList dense_summaries_output_list; 377 OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName, 378 &dense_summaries_output_list)); 379 380 auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) { 381 auto copy_over_summaries = [&](const QuantileStream& stream, 382 const int64 index, 383 OpOutputList* output_list) { 384 protobuf::Arena arena; 385 ::boosted_trees::QuantileSummaryState* summary_proto = 386 protobuf::Arena::CreateMessage< 387 ::boosted_trees::QuantileSummaryState>(&arena); 388 const auto& summary = stream.GetFinalSummary(); 389 CopySummaryToProto(summary, summary_proto); 390 // Output to tensor. 391 Tensor* output_t = nullptr; 392 OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t)); 393 summary_proto->SerializeToString(&output_t->scalar<string>()()); 394 }; 395 396 // These are blocks of ranges. We are iterating over both sparse and 397 // dense features i.e. [0, sparse_features.size() + dense_features.size()] 398 for (int64 i = begin; i < end; ++i) { 399 if (i < num_dense_features_) { 400 const int64 dense_index = i; 401 const auto dense_values = 402 dense_float_features_list[dense_index].flat<float>(); 403 QuantileStream stream(epsilon_, batch_size + 1); 404 // Run quantile summary generation. 405 for (int64 j = 0; j < batch_size; ++j) { 406 stream.PushEntry(dense_values(j), example_weights(j)); 407 } 408 stream.Finalize(); 409 // Copy summaries to output. 410 copy_over_summaries(stream, dense_index, 411 &dense_summaries_output_list); 412 } else { 413 const int64 sparse_index = i - num_dense_features_; 414 const auto sparse_values = 415 sparse_float_feature_values_list[sparse_index].flat<float>(); 416 const auto sparse_indices = 417 sparse_float_feature_indices_list[sparse_index].matrix<int64>(); 418 const auto dense_shape = 419 sparse_float_feature_shapes_list[sparse_index].flat<int64>(); 420 OP_REQUIRES(context, batch_size == dense_shape(0), 421 errors::InvalidArgument( 422 "Sparse column shape doesn't match the batch size.")); 423 QuantileStream stream(epsilon_, batch_size + 1); 424 // Run quantile summary generation. 425 const int64 num_sparse_rows = 426 sparse_float_feature_indices_list[sparse_index].dim_size(0); 427 for (int64 j = 0; j < num_sparse_rows; ++j) { 428 const int64 example_id = sparse_indices(j, 0); 429 stream.PushEntry(sparse_values(j), example_weights(example_id)); 430 } 431 stream.Finalize(); 432 // Copy summaries to output. 433 copy_over_summaries(stream, sparse_index, 434 &sparse_summaries_output_list); 435 } 436 } 437 }; 438 const int64 kCostPerUnit = 500 * batch_size; 439 const int64 num_features = num_sparse_features_ + num_dense_features_; 440 const DeviceBase::CpuWorkerThreads& worker_threads = 441 *context->device()->tensorflow_cpu_worker_threads(); 442 Shard(worker_threads.num_threads, worker_threads.workers, num_features, 443 kCostPerUnit, do_quantile_summary_gen); 444 } 445 446 private: 447 int num_dense_features_; 448 int num_sparse_features_; 449 float epsilon_; 450}; 451 452REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU), 453 MakeQuantileSummariesOp); 454 455// Serializes the state of streams. 456class QuantileAccumulatorSerializeOp : public OpKernel { 457 public: 458 explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context) 459 : OpKernel(context) {} 460 461 void Compute(OpKernelContext* context) override { 462 QuantileStreamResource* streams_resource; 463 // Create a reference to the underlying resource using the handle. 464 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 465 &streams_resource)); 466 // Remove the reference at the end of this scope. 467 mutex_lock l(*streams_resource->mutex()); 468 core::ScopedUnref unref_me(streams_resource); 469 470 int64 stamp_token = streams_resource->stamp(); 471 Tensor* stream_state_t; 472 OP_REQUIRES_OK(context, 473 context->allocate_output(kStreamStateName, TensorShape({}), 474 &stream_state_t)); 475 bool are_buckets_ready = streams_resource->are_buckets_ready(); 476 477 // We are iterating over both dense and sparse features. First we go 478 // through the dense features and then the sparse features. 479 const QuantileStream& stream = *streams_resource->stream(stamp_token); 480 const std::vector<float>& boundaries = 481 are_buckets_ready ? streams_resource->boundaries(stamp_token) 482 : std::vector<float>(); 483 protobuf::Arena arena; 484 ::boosted_trees::QuantileStreamState* stream_proto = 485 protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>( 486 &arena); 487 for (const auto& summary : stream.SerializeInternalSummaries()) { 488 CopySummaryToProto(summary, stream_proto->add_summaries()); 489 } 490 stream_proto->SerializeToString(&stream_state_t->scalar<string>()()); 491 Tensor* buckets_t = nullptr; 492 OP_REQUIRES_OK( 493 context, 494 context->allocate_output( 495 kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t)); 496 auto* quantiles_flat = buckets_t->flat<float>().data(); 497 memcpy(quantiles_flat, boundaries.data(), 498 sizeof(float) * boundaries.size()); 499 Tensor* stamp_token_t = nullptr; 500 OP_REQUIRES_OK(context, 501 context->allocate_output(kStampTokenName, TensorShape({}), 502 &stamp_token_t)); 503 stamp_token_t->scalar<int64>()() = stamp_token; 504 Tensor* are_buckets_ready_t = nullptr; 505 OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {}, 506 &are_buckets_ready_t)); 507 are_buckets_ready_t->scalar<bool>()() = are_buckets_ready; 508 } 509}; 510 511REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU), 512 QuantileAccumulatorSerializeOp); 513 514// Serializes the state of streams. 515class QuantileAccumulatorDeserializeOp : public OpKernel { 516 public: 517 explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context) 518 : OpKernel(context) {} 519 520 void Compute(OpKernelContext* context) override { 521 QuantileStreamResource* streams_resource; 522 // Create a reference to the underlying resource using the handle. 523 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 524 &streams_resource)); 525 // Remove the reference at the end of this scope. 526 mutex_lock l(*streams_resource->mutex()); 527 core::ScopedUnref unref_me(streams_resource); 528 529 int64 old_stamp_token = streams_resource->stamp(); 530 531 const Tensor* stream_state_t; 532 OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t)); 533 const Tensor* buckets_t; 534 OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t)); 535 536 QuantileStream* stream = streams_resource->stream(old_stamp_token); 537 ::boosted_trees::QuantileStreamState state_proto; 538 OP_REQUIRES( 539 context, 540 ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()), 541 errors::InvalidArgument("Unabnle to parse quantile stream state.")); 542 std::vector<QuantileSummary> summaries; 543 summaries.reserve(state_proto.summaries_size()); 544 std::vector<QuantileSummaryEntry> entries; 545 for (const auto& summary : state_proto.summaries()) { 546 entries.clear(); 547 entries.reserve(summary.entries_size()); 548 for (const auto& entry : summary.entries()) { 549 entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(), 550 entry.max_rank()); 551 } 552 summaries.emplace_back(); 553 summaries[summaries.size() - 1].BuildFromSummaryEntries(entries); 554 } 555 stream->DeserializeInternalSummaries(summaries); 556 557 const auto& buckets = buckets_t->vec<float>(); 558 std::vector<float> result; 559 result.reserve(buckets.size()); 560 561 for (size_t i = 0; i < buckets.size(); ++i) { 562 result.push_back(buckets(i)); 563 } 564 streams_resource->set_boundaries(old_stamp_token, result); 565 566 // Reset the stamp token. 567 const Tensor* stamp_token_t = nullptr; 568 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 569 int64 stamp_token = stamp_token_t->scalar<int64>()(); 570 streams_resource->set_stamp(stamp_token); 571 572 const Tensor* are_buckets_ready_t = nullptr; 573 OP_REQUIRES_OK(context, 574 context->input(kAreBucketsReadyName, &are_buckets_ready_t)); 575 streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()()); 576 } 577}; 578 579REGISTER_KERNEL_BUILDER( 580 Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU), 581 QuantileAccumulatorDeserializeOp); 582 583// Flushes the quantile summary stream resource. 584class QuantileAccumulatorFlushOp : public OpKernel { 585 public: 586 explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context) 587 : OpKernel(context) {} 588 589 void Compute(OpKernelContext* context) override { 590 QuantileStreamResource* streams_resource; 591 // Create a reference to the underlying resource using the handle. 592 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 593 &streams_resource)); 594 // Remove the reference at the end of this scope. 595 mutex_lock l(*streams_resource->mutex()); 596 core::ScopedUnref unref_me(streams_resource); 597 598 const Tensor* next_stamp_token_t; 599 OP_REQUIRES_OK(context, 600 context->input(kNextStampTokenName, &next_stamp_token_t)); 601 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 602 603 const Tensor* stamp_token_t; 604 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 605 int64 stamp_token = stamp_token_t->scalar<int64>()(); 606 CHECK(streams_resource->is_stamp_valid(stamp_token)) 607 << "Invalid stamp token in QuantileAccumulatorFlushOp. " 608 << "Passed stamp token: " << stamp_token << " " 609 << "Current token: " << streams_resource->stamp(); 610 QuantileStream* stream = streams_resource->stream(stamp_token); 611 bool generate_quantiles = streams_resource->generate_quantiles(); 612 stream->Finalize(); 613 614 streams_resource->set_boundaries( 615 stamp_token, 616 generate_quantiles 617 ? GenerateQuantiles(*stream, streams_resource->num_quantiles()) 618 : GenerateBoundaries(*stream, streams_resource->num_quantiles())); 619 620 streams_resource->Reset(next_stamp_token); 621 } 622}; 623 624REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU), 625 QuantileAccumulatorFlushOp); 626 627// Flushes the quantile summary stream resource. This version computes the 628// summary. 629class QuantileAccumulatorFlushSummaryOp : public OpKernel { 630 public: 631 explicit QuantileAccumulatorFlushSummaryOp( 632 OpKernelConstruction* const context) 633 : OpKernel(context) {} 634 635 void Compute(OpKernelContext* context) override { 636 QuantileStreamResource* streams_resource; 637 // Create a reference to the underlying resource using the handle. 638 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 639 &streams_resource)); 640 // Remove the reference at the end of this scope. 641 mutex_lock l(*streams_resource->mutex()); 642 core::ScopedUnref unref_me(streams_resource); 643 644 const Tensor* next_stamp_token_t; 645 OP_REQUIRES_OK(context, 646 context->input(kNextStampTokenName, &next_stamp_token_t)); 647 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 648 649 const Tensor* stamp_token_t; 650 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 651 int64 stamp_token = stamp_token_t->scalar<int64>()(); 652 CHECK(streams_resource->is_stamp_valid(stamp_token)) 653 << "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. " 654 << "Passed stamp token: " << stamp_token << " " 655 << "Current token: " << streams_resource->stamp(); 656 QuantileStream* stream = streams_resource->stream(stamp_token); 657 stream->Finalize(); 658 protobuf::Arena arena; 659 ::boosted_trees::QuantileSummaryState* summary_proto = 660 protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>( 661 &arena); 662 const auto& summary = stream->GetFinalSummary(); 663 CopySummaryToProto(summary, summary_proto); 664 // Output to tensor. 665 Tensor* output_t = nullptr; 666 OP_REQUIRES_OK(context, 667 context->allocate_output(0, TensorShape({}), &output_t)); 668 summary_proto->SerializeToString(&output_t->scalar<string>()()); 669 streams_resource->Reset(next_stamp_token); 670 } 671}; 672 673REGISTER_KERNEL_BUILDER( 674 Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU), 675 QuantileAccumulatorFlushSummaryOp); 676 677// Get bucket boundaries from summaries. 678class QuantileAccumulatorGetBucketsOp : public OpKernel { 679 public: 680 explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context) 681 : OpKernel(context) {} 682 683 void Compute(OpKernelContext* const context) override { 684 OpInputList resource_handle_list; 685 OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName, 686 &resource_handle_list)); 687 OpOutputList are_buckets_ready_list; 688 OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName, 689 &are_buckets_ready_list)); 690 OpOutputList buckets_list; 691 OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list)); 692 const Tensor* stamp_token_t; 693 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 694 int64 stamp_token = stamp_token_t->scalar<int64>()(); 695 696 thread::ThreadPool* const worker_threads = 697 context->device()->tensorflow_cpu_worker_threads()->workers; 698 boosted_trees::utils::ParallelFor( 699 resource_handle_list.size(), worker_threads->NumThreads(), 700 worker_threads, 701 [&context, &resource_handle_list, &are_buckets_ready_list, 702 &buckets_list, stamp_token](int64 start, int64 end) { 703 for (int resource_handle_idx = start; resource_handle_idx < end; 704 ++resource_handle_idx) { 705 ResourceHandle handle = resource_handle_list[resource_handle_idx] 706 .flat<ResourceHandle>()(0); 707 QuantileStreamResource* streams_resource; 708 OP_REQUIRES_OK(context, 709 LookupResource(context, handle, &streams_resource)); 710 // Remove the reference at the end of this scope. 711 mutex_lock l(*streams_resource->mutex()); 712 core::ScopedUnref unref_me(streams_resource); 713 714 bool are_buckets_ready = 715 streams_resource->is_stamp_valid(stamp_token) && 716 streams_resource->are_buckets_ready(); 717 718 Tensor* are_buckets_ready_t = nullptr; 719 OP_REQUIRES_OK(context, 720 are_buckets_ready_list.allocate( 721 resource_handle_idx, {}, &are_buckets_ready_t)); 722 are_buckets_ready_t->scalar<bool>()() = are_buckets_ready; 723 724 const std::vector<float>& boundaries = 725 are_buckets_ready ? streams_resource->boundaries(stamp_token) 726 : std::vector<float>(); 727 Tensor* output_t = nullptr; 728 OP_REQUIRES_OK(context, buckets_list.allocate( 729 resource_handle_idx, 730 {static_cast<int64>(boundaries.size())}, 731 &output_t)); 732 auto* quantiles_flat = output_t->flat<float>().data(); 733 memcpy(quantiles_flat, boundaries.data(), 734 sizeof(float) * boundaries.size()); 735 } 736 }); 737 } 738}; 739 740REGISTER_KERNEL_BUILDER( 741 Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU), 742 QuantileAccumulatorGetBucketsOp); 743 744// Generates buckets for given set of float values, and the given config. 745class QuantileBucketsOp : public OpKernel { 746 public: 747 explicit QuantileBucketsOp(OpKernelConstruction* const context) 748 : OpKernel(context) { 749 OP_REQUIRES_OK(context, 750 ReadAndValidateAttributes(context, &num_dense_features_, 751 &num_sparse_features_)); 752 753 ParseConfig(context, kDenseConfigName, &dense_configs_); 754 OP_REQUIRES(context, dense_configs_.size() == num_dense_features_, 755 errors::InvalidArgument( 756 "Mismatch in number of dense quantile configs.")); 757 ParseConfig(context, kSparseConfigName, &sparse_configs_); 758 OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_, 759 errors::InvalidArgument( 760 "Mismatch in number of sparse quantile configs.")); 761 } 762 763 void Compute(OpKernelContext* const context) override { 764 // Read dense float features list; 765 OpInputList dense_float_features_list; 766 OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( 767 context, &dense_float_features_list)); 768 769 // Read sparse float features list; 770 OpInputList sparse_float_feature_indices_list; 771 OpInputList sparse_float_feature_values_list; 772 OpInputList sparse_float_feature_shapes_list; 773 OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures( 774 context, &sparse_float_feature_indices_list, 775 &sparse_float_feature_values_list, 776 &sparse_float_feature_shapes_list)); 777 778 // Parse example weights and get batch size. 779 const Tensor* example_weights_t; 780 OP_REQUIRES_OK(context, 781 context->input(kExampleWeightsName, &example_weights_t)); 782 auto example_weights = example_weights_t->flat<float>(); 783 const int64 batch_size = example_weights.size(); 784 785 OpOutputList sparse_buckets_output_list; 786 OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName, 787 &sparse_buckets_output_list)); 788 OpOutputList dense_buckets_output_list; 789 OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName, 790 &dense_buckets_output_list)); 791 792 auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) { 793 // These are blocks of ranges. We are iterating over both sparse and 794 // dense features i.e. [0, sparse_features.size() + dense_features.size()] 795 for (int64 i = begin; i < end; ++i) { 796 if (i < sparse_configs_.size()) { 797 const int64 sparse_index = i; 798 const auto sparse_values = 799 sparse_float_feature_values_list[sparse_index].flat<float>(); 800 const auto sparse_indices = 801 sparse_float_feature_indices_list[sparse_index].matrix<int64>(); 802 QuantileStream stream(sparse_configs_[sparse_index].eps(), 803 batch_size); 804 // Run quantile summary generation. 805 const int64 num_sparse_rows = 806 sparse_float_feature_indices_list[sparse_index].dim_size(0); 807 for (int64 j = 0; j < num_sparse_rows; ++j) { 808 const int64 example_id = sparse_indices(j, 0); 809 stream.PushEntry(sparse_values(j), example_weights(example_id)); 810 } 811 stream.Finalize(); 812 // Create buckets. 813 const auto boundaries = GenerateBoundaries( 814 stream, sparse_configs_[sparse_index].num_quantiles()); 815 CopyBoundaries(context, boundaries, sparse_index, 816 &sparse_buckets_output_list); 817 818 } else { 819 const int64 dense_index = i - sparse_configs_.size(); 820 const auto dense_values = 821 dense_float_features_list[dense_index].flat<float>(); 822 QuantileStream stream(dense_configs_[dense_index].eps(), batch_size); 823 // Run quantile summary generation. 824 for (int64 j = 0; j < batch_size; ++j) { 825 stream.PushEntry(dense_values(j), example_weights(j)); 826 } 827 stream.Finalize(); 828 // Create buckets. 829 const auto boundaries = GenerateBoundaries( 830 stream, dense_configs_[dense_index].num_quantiles()); 831 CopyBoundaries(context, boundaries, dense_index, 832 &dense_buckets_output_list); 833 } 834 } 835 }; 836 837 const int64 kCostPerUnit = 500 * batch_size; 838 const int64 num_features = sparse_configs_.size() + dense_configs_.size(); 839 const DeviceBase::CpuWorkerThreads& worker_threads = 840 *context->device()->tensorflow_cpu_worker_threads(); 841 Shard(worker_threads.num_threads, worker_threads.workers, num_features, 842 kCostPerUnit, do_quantile_bucket_gen); 843 } 844 845 private: 846 int num_dense_features_; 847 int num_sparse_features_; 848 std::vector<QuantileConfig> dense_configs_; 849 std::vector<QuantileConfig> sparse_configs_; 850}; 851 852REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU), 853 QuantileBucketsOp); 854 855// Given the calculated quantiles thresholds and input data, this operation 856// converts the input features into the buckets (categorical values), depending 857// on which quantile they fall into. 858class QuantilesOp : public OpKernel { 859 public: 860 explicit QuantilesOp(OpKernelConstruction* const context) 861 : OpKernel(context) { 862 int num_dense_features; 863 int num_sparse_features; 864 OP_REQUIRES_OK(context, 865 ReadAndValidateAttributes(context, &num_dense_features, 866 &num_sparse_features)); 867 } 868 869 void Compute(OpKernelContext* const context) override { 870 // Dense features inputs 871 OpInputList dense_float_features_list; 872 OP_REQUIRES_OK(context, context->input_list(kDenseValuesName, 873 &dense_float_features_list)); 874 OpInputList dense_buckets_list; 875 OP_REQUIRES_OK(context, 876 context->input_list(kDenseBucketsName, &dense_buckets_list)); 877 878 if (dense_buckets_list.size() > 0) { 879 // Check the first tensor to make sure it is the right shape 880 OP_REQUIRES( 881 context, 882 tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()), 883 errors::InvalidArgument( 884 strings::Printf("Dense buckets should be flat vectors"))); 885 } 886 887 // Sparse features inputs 888 OpInputList sparse_float_feature_values_list; 889 OP_REQUIRES_OK(context, 890 context->input_list(kSparseValuesName, 891 &sparse_float_feature_values_list)); 892 893 OpInputList sparse_float_indices_list; 894 OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName, 895 &sparse_float_indices_list)); 896 897 OpInputList sparse_buckets_list; 898 OP_REQUIRES_OK( 899 context, context->input_list(kSparseBucketsName, &sparse_buckets_list)); 900 901 if (sparse_buckets_list.size() > 0) { 902 OP_REQUIRES( 903 context, 904 tensorflow::TensorShapeUtils::IsVector( 905 sparse_buckets_list[0].shape()), 906 errors::InvalidArgument("Sparse buckets should be flat vectors")); 907 } 908 909 // Quantize the feature values 910 QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list, 911 dense_buckets_list, nullptr, context); 912 913 QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list, 914 sparse_buckets_list, &sparse_float_indices_list, context); 915 } 916}; 917 918REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp); 919 920template <typename T> 921class BucketizeWithInputBoundariesOp : public OpKernel { 922 public: 923 explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context) 924 : OpKernel(context) {} 925 926 void Compute(OpKernelContext* context) override { 927 const Tensor& boundaries_tensor = context->input(1); 928 VLOG(1) << "boundaries has shape: " 929 << boundaries_tensor.shape().DebugString(); 930 auto boundaries = boundaries_tensor.flat<float>(); 931 std::vector<T> boundaries_vector; 932 boundaries_vector.reserve(boundaries.size()); 933 for (size_t i = 0; i < boundaries.size(); i++) { 934 boundaries_vector.push_back(boundaries(i)); 935 VLOG(1) << "boundaries(" << i << ") : " << boundaries(i); 936 } 937 OP_REQUIRES( 938 context, 939 std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()), 940 errors::InvalidArgument("Expected sorted boundaries")); 941 942 const Tensor& input_tensor = context->input(0); 943 VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString() 944 << " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype()); 945 auto input = input_tensor.flat<T>(); 946 947 Tensor* output_tensor = nullptr; 948 OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 949 &output_tensor)); 950 auto output = output_tensor->template flat<int32>(); 951 952 for (size_t i = 0; i < input.size(); i++) { 953 output(i) = CalculateBucketIndex(input(i), boundaries_vector); 954 } 955 } 956 957 private: 958 int32 CalculateBucketIndex(const T value, std::vector<T>& boundaries_vector) { 959 auto first_bigger_it = std::upper_bound(boundaries_vector.begin(), 960 boundaries_vector.end(), value); 961 int32 index = first_bigger_it - boundaries_vector.begin(); 962 CHECK(index >= 0 && index <= boundaries_vector.size()) 963 << "Invalid bucket index: " << index 964 << " boundaries_vector.size(): " << boundaries_vector.size(); 965 return index; 966 } 967}; 968 969#define REGISTER_KERNEL(T) \ 970 REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \ 971 .Device(DEVICE_CPU) \ 972 .TypeConstraint<T>("T"), \ 973 BucketizeWithInputBoundariesOp<T>); 974 975REGISTER_KERNEL(int32); 976REGISTER_KERNEL(int64); 977REGISTER_KERNEL(float); 978REGISTER_KERNEL(double); 979#undef REGISTER_KERNEL 980 981} // namespace tensorflow 982