1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// Contains OP to generate sparse crosses. 17#include <assert.h> 18#include <limits> 19#include <string> 20#include <vector> 21 22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23#include "tensorflow/core/framework/kernel_def_builder.h" 24#include "tensorflow/core/framework/op_def_builder.h" 25#include "tensorflow/core/framework/op_kernel.h" 26#include "tensorflow/core/framework/tensor.h" 27#include "tensorflow/core/framework/tensor_shape.h" 28#include "tensorflow/core/framework/types.h" 29#include "tensorflow/core/lib/core/stringpiece.h" 30#include "tensorflow/core/lib/strings/str_util.h" 31#include "tensorflow/core/platform/fingerprint.h" 32#include "tensorflow/core/util/work_sharder.h" 33 34namespace tensorflow { 35 36namespace { 37// An interface that represents a column with batches. 38template <typename InternalType> 39class ColumnInterface { 40 public: 41 // Returns the number of features in the specified batch. 42 virtual int64 FeatureCount(int64 batch) const = 0; 43 44 // Returns the fingerprint of nth feature from the specified batch. 45 virtual InternalType Feature(int64 batch, int64 n) const = 0; 46 47 virtual ~ColumnInterface() {} 48}; 49 50// A column that is backed by a sparse tensor. 51template <typename InternalType> 52class SparseTensorColumn : public ColumnInterface<InternalType> { 53 public: 54 SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts, 55 std::vector<int64> feature_start_indices) 56 : values_(values), 57 feature_counts_(std::move(feature_counts)), 58 feature_start_indices_(std::move(feature_start_indices)) { 59 CHECK_EQ(feature_counts_.size(), feature_start_indices_.size()); 60 } 61 62 int64 FeatureCount(int64 batch) const override { 63 return feature_counts_[batch]; 64 } 65 66 InternalType Feature(int64 batch, int64 n) const override; 67 68 ~SparseTensorColumn() override {} 69 70 private: 71 const Tensor& values_; 72 std::vector<int64> feature_counts_; 73 std::vector<int64> feature_start_indices_; 74}; 75 76// InternalType is int64 only when using HashCrosser. 77template <> 78int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const { 79 const int64 start = feature_start_indices_[batch]; 80 if (DT_STRING == values_.dtype()) 81 return Fingerprint64(values_.vec<string>().data()[start + n]); 82 return values_.vec<int64>().data()[start + n]; 83} 84 85// InternalType is string or StringPiece when using StringCrosser. 86template <> 87string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const { 88 const int64 start = feature_start_indices_[batch]; 89 if (DT_STRING == values_.dtype()) 90 return values_.vec<string>().data()[start + n]; 91 return std::to_string(values_.vec<int64>().data()[start + n]); 92} 93 94template <> 95StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, 96 int64 n) const { 97 const int64 start = feature_start_indices_[batch]; 98 return values_.vec<string>().data()[start + n]; 99} 100 101// A column that is backed by a dense tensor. 102template <typename InternalType> 103class DenseTensorColumn : public ColumnInterface<InternalType> { 104 public: 105 explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {} 106 107 int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } 108 109 InternalType Feature(int64 batch, int64 n) const override; 110 111 ~DenseTensorColumn() override {} 112 113 private: 114 const Tensor& tensor_; 115}; 116 117// InternalType is int64 only when using HashCrosser. 118template <> 119int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const { 120 if (DT_STRING == tensor_.dtype()) 121 return Fingerprint64(tensor_.matrix<string>()(batch, n)); 122 return tensor_.matrix<int64>()(batch, n); 123} 124 125// Internal type is string or StringPiece when using StringCrosser. 126template <> 127string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const { 128 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n); 129 return std::to_string(tensor_.matrix<int64>()(batch, n)); 130} 131 132template <> 133StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, 134 int64 n) const { 135 return tensor_.matrix<string>()(batch, n); 136} 137 138// Updates Output tensors with sparse crosses. 139template <typename OutType> 140class OutputUpdater { 141 public: 142 OutputUpdater(const std::vector<int64>& output_start_indices, 143 Tensor* indices_out, Tensor* values_out) 144 : output_start_indices_(output_start_indices), 145 indices_out_(indices_out), 146 values_out_(values_out) {} 147 148 void Update(const int64 batch_index, const int64 cross_count, 149 const OutType& cross) const { 150 const int64 output_index = output_start_indices_[batch_index] + cross_count; 151 152 auto indices_matrix = indices_out_->matrix<int64>(); 153 indices_matrix(output_index, 0) = batch_index; 154 indices_matrix(output_index, 1) = cross_count; 155 156 auto value_vec = values_out_->vec<OutType>(); 157 value_vec(output_index) = cross; 158 } 159 160 private: 161 const std::vector<int64>& output_start_indices_; 162 Tensor* indices_out_; 163 Tensor* values_out_; 164}; 165 166// Generates the sparse crosses as concatenation of strings. 167template <typename InternalType> 168class StringCrosser { 169 public: 170 StringCrosser(const std::vector< 171 std::unique_ptr<ColumnInterface<InternalType>>>& columns, 172 const int64 num_buckets_unused, const uint64 hash_key_unused) 173 : columns_(columns) {} 174 175 string Generate(const int64 batch_index, 176 const std::vector<int>& permutation) const { 177 static const auto k_feature_separator = "_X_"; 178 179 gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size()); 180 for (size_t i = 0; i < permutation.size(); i++) { 181 cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]); 182 } 183 // TODO(zakaria): this will copy the string twice, might effect 184 // performance. 185 return str_util::Join(cross_vec, k_feature_separator); 186 } 187 188 private: 189 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_; 190}; 191 192// Generates the sparse crosses as nested hash to avoid string manipulations. 193class HashCrosser { 194 public: 195 HashCrosser( 196 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns, 197 const int64 num_buckets, const uint64 hash_key_unused) 198 : columns_(columns), num_buckets_(num_buckets) {} 199 200 int64 Generate(const int64 batch_index, 201 const std::vector<int>& permutation) const { 202 // Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h 203 static const int64 kInitialHashSeed = 0xDECAFCAFFE; 204 205 uint64 hashed_output = kInitialHashSeed; 206 for (size_t i = 0; i < permutation.size(); ++i) { 207 int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]); 208 hashed_output = HashCombine(hashed_output, hash_i); 209 } 210 if (num_buckets_ > 0) { 211 return hashed_output % num_buckets_; 212 } else { 213 // To prevent negative output we take modulo to max int64. 214 return hashed_output % std::numeric_limits<int64>::max(); 215 } 216 } 217 218 private: 219 static int64 HashCombine(int64 a, int64 b) { 220 return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4)); 221 } 222 223 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_; 224 const int64 num_buckets_; 225}; 226 227// Generates the sparse crosses as nested hash to avoid string manipulations. 228class HashCrosserV2 { 229 public: 230 HashCrosserV2( 231 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns, 232 const int64 num_buckets, const uint64 hash_key) 233 : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {} 234 235 int64 Generate(const int64 batch_index, 236 const std::vector<int>& permutation) const { 237 // Do the fingerprint concatenation on uint64. 238 uint64 hashed_output = hash_key_; 239 for (size_t i = 0; i < permutation.size(); ++i) { 240 uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]); 241 hashed_output = FingerprintCat64(hashed_output, hash_i); 242 } 243 // The return value is int64 based on the number of buckets. 244 if (num_buckets_ > 0) { 245 return hashed_output % num_buckets_; 246 } else { 247 // To prevent negative output we take modulo to max int64. 248 return hashed_output % std::numeric_limits<int64>::max(); 249 } 250 } 251 252 private: 253 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_; 254 const int64 num_buckets_; 255 const uint64 hash_key_; 256}; 257 258// ProductIterator generates cartesian products based on indices. 259template <typename InternalType> 260class ProductIterator { 261 public: 262 explicit ProductIterator( 263 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 264 columns, 265 int64 batch_index) 266 : columns_(columns), batch_index_(batch_index) { 267 next_permutation_.resize(columns_.size(), 0); 268 // Sets has_next_ to false if any feature column has 0 features. 269 has_next_ = true; 270 for (size_t i = 0; i < columns_.size(); i++) { 271 if (columns_[i]->FeatureCount(batch_index_) == 0) { 272 has_next_ = false; 273 break; 274 } 275 } 276 } 277 278 std::vector<int> Next() { 279 std::vector<int> permutation(next_permutation_); 280 281 // Generates next permutation, if available. 282 bool carry = true; 283 for (int i = next_permutation_.size() - 1; i >= 0; i--) { 284 if (carry) { 285 next_permutation_[i] = next_permutation_[i] + 1; 286 } 287 if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) { 288 next_permutation_[i] = 0; 289 } else { 290 carry = false; 291 break; 292 } 293 } 294 has_next_ = !carry; 295 return permutation; 296 } 297 298 bool HasNext() { return has_next_; } 299 300 private: 301 bool has_next_; 302 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_; 303 const int64 batch_index_; 304 std::vector<int> next_permutation_; 305}; 306 307template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2> 308struct CrossTraits; 309 310template <typename InternalType, bool VERSION_2> 311struct CrossTraits<false, InternalType, VERSION_2> { 312 typedef StringCrosser<InternalType> Crosser; 313 typedef OutputUpdater<string> Updater; 314}; 315 316template <> 317struct CrossTraits<true, int64, false> { 318 typedef HashCrosser Crosser; 319 typedef OutputUpdater<int64> Updater; 320}; 321 322template <> 323struct CrossTraits<true, int64, true> { 324 typedef HashCrosserV2 Crosser; 325 typedef OutputUpdater<int64> Updater; 326}; 327} // namespace 328 329template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2> 330class SparseFeatureCrossOp : public OpKernel { 331 public: 332 explicit SparseFeatureCrossOp(OpKernelConstruction* context) 333 : OpKernel(context) { 334 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_)); 335 if (VERSION_2) { 336 // Read signed_hash_key_ as int64 since uint64 attributes are not 337 // supported by REGISTER_OP. 338 int64 signed_hash_key_; 339 OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_)); 340 hash_key_ = static_cast<uint64>(signed_hash_key_); 341 } 342 } 343 344 void Compute(OpKernelContext* context) override { 345 OpInputList indices_list_in; 346 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in)); 347 OpInputList values_list_in; 348 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in)); 349 OpInputList shapes_list_in; 350 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in)); 351 OpInputList dense_list_in; 352 OP_REQUIRES_OK(context, context->input_list("dense", &dense_list_in)); 353 354 ValidateInput(context, indices_list_in, values_list_in, shapes_list_in, 355 dense_list_in); 356 357 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns = 358 GenerateColumnsFromInput(indices_list_in, values_list_in, 359 shapes_list_in, dense_list_in); 360 361 typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Crosser 362 crosser(columns, num_buckets_, hash_key_); 363 Tensor* indices_out; 364 Tensor* values_out; 365 Tensor* shape_out; 366 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 367 std::vector<int64> output_start_indices(batch_size); 368 CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out, 369 &shape_out, &output_start_indices); 370 371 typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Updater 372 updater(output_start_indices, indices_out, values_out); 373 auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) { 374 for (int b = begin; b < end; b++) { 375 ProductIterator<InternalType> product_iterator(columns, b); 376 int64 cross_count = 0; 377 while (product_iterator.HasNext()) { 378 const auto permutation = product_iterator.Next(); 379 updater.Update(b, cross_count, crosser.Generate(b, permutation)); 380 cross_count++; 381 } 382 } 383 }; 384 385 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads(); 386 // TODO(zakaria): optimize kCostPerUnit 387 const int kCostPerUnit = 5000 * indices_list_in.size(); 388 Shard(worker_threads->num_threads, worker_threads->workers, batch_size, 389 kCostPerUnit, do_work); 390 } 391 392 private: 393 // Validates input tensors. 394 void ValidateInput(OpKernelContext* context, 395 const OpInputList& indices_list_in, 396 const OpInputList& values_list_in, 397 const OpInputList& shapes_list_in, 398 const OpInputList& dense_list_in) { 399 const auto size = indices_list_in.size(); 400 // Validates indices_list_in OpInputList. 401 for (int i = 0; i < size; i++) { 402 OP_REQUIRES( 403 context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()), 404 errors::InvalidArgument( 405 "Input indices should be a matrix but received shape ", 406 indices_list_in[i].shape().DebugString(), " at position ", i)); 407 OP_REQUIRES( 408 context, indices_list_in[i].shape().dim_size(1) == 2, 409 errors::InvalidArgument("Expected D2 of index to be 2 got ", 410 indices_list_in[i].shape().dim_size(1), 411 " at position ", i)); 412 } 413 414 // Validates values_list_in OpInputList. 415 OP_REQUIRES( 416 context, values_list_in.size() == size, 417 errors::InvalidArgument("Expected ", size, " input values, got ", 418 values_list_in.size())); 419 for (int i = 0; i < size; i++) { 420 OP_REQUIRES( 421 context, TensorShapeUtils::IsVector(values_list_in[i].shape()), 422 errors::InvalidArgument( 423 "Input values should be a std::vector but received shape ", 424 values_list_in[i].shape().DebugString(), " at position ", i)); 425 OP_REQUIRES( 426 context, 427 indices_list_in[i].shape().dim_size(0) == 428 values_list_in[i].shape().dim_size(0), 429 errors::InvalidArgument( 430 "Expected size of values to be ", 431 indices_list_in[i].shape().dim_size(0), " got ", 432 values_list_in[i].shape().dim_size(0), " at position ", i)); 433 } 434 435 // Validates shapes_list_in OpInputList 436 OP_REQUIRES( 437 context, shapes_list_in.size() == size, 438 errors::InvalidArgument("Expected ", size, " input shapes, got ", 439 shapes_list_in.size())); 440 const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 441 for (int i = 0; i < size; i++) { 442 OP_REQUIRES( 443 context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()), 444 errors::InvalidArgument( 445 "Input shapes should be a std::vector but received shape ", 446 shapes_list_in[i].shape().DebugString(), " at position ", i)); 447 448 OP_REQUIRES( 449 context, shapes_list_in[i].vec<int64>().size() == 2, 450 errors::InvalidArgument("shape should imply a 2D tensor, but got ", 451 shapes_list_in[i].shape().DebugString(), 452 " at position ", i)); 453 OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size, 454 errors::InvalidArgument( 455 "Expected batch size ", batch_size, " got ", 456 shapes_list_in[i].vec<int64>()(0), " at position ", i)); 457 } 458 459 // Validates dense_list_in OpInputList 460 for (int i = 0; i < dense_list_in.size(); ++i) { 461 OP_REQUIRES( 462 context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()), 463 errors::InvalidArgument( 464 "Dense inputs should be a matrix but received shape ", 465 indices_list_in[i].shape().DebugString(), " at position ", i)); 466 OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size, 467 errors::InvalidArgument("Expected batch size ", batch_size, 468 " got ", dense_list_in[i].dim_size(0), 469 " at dense tensor ", i)); 470 } 471 } 472 473 // Calculate the batch size from either the shapes input or the dense input. 474 int64 CalculateBatchSize(const OpInputList& shapes_list_in, 475 const OpInputList& dense_list_in) { 476 if (shapes_list_in.size() > 0) { 477 return shapes_list_in[0].vec<int64>()(0); 478 } 479 480 if (dense_list_in.size() > 0) { 481 return dense_list_in[0].dim_size(0); 482 } 483 484 return 0; 485 } 486 487 // Generate the columns given the sparse and dense inputs. 488 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> 489 GenerateColumnsFromInput(const OpInputList& indices_list_in, 490 const OpInputList& values_list_in, 491 const OpInputList& shapes_list_in, 492 const OpInputList& dense_list_in) { 493 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns; 494 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 495 const int64 number_of_columns = shapes_list_in.size(); 496 497 std::vector<std::vector<int64>> feature_counts(number_of_columns, 498 std::vector<int64>()); 499 std::vector<std::vector<int64>> feature_start_indices(number_of_columns, 500 std::vector<int64>()); 501 502 ExtractFeatureData(indices_list_in, batch_size, &feature_counts, 503 &feature_start_indices); 504 505 columns.reserve(values_list_in.size()); 506 for (int i = 0; i < values_list_in.size(); ++i) { 507 columns.emplace_back(new SparseTensorColumn<InternalType>( 508 values_list_in[i], std::move(feature_counts[i]), 509 std::move(feature_start_indices[i]))); 510 } 511 for (int i = 0; i < dense_list_in.size(); ++i) { 512 columns.emplace_back( 513 new DenseTensorColumn<InternalType>(dense_list_in[i])); 514 } 515 516 return columns; 517 } 518 519 // Extracts data about the features and populates feature data. 520 void ExtractFeatureData( 521 const OpInputList& indices_list_in, int64 batch_size, 522 std::vector<std::vector<int64>>* feature_counts, 523 std::vector<std::vector<int64>>* feature_start_indices) { 524 gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0); 525 for (int b = 0; b < batch_size; b++) { 526 for (int i = 0; i < indices_list_in.size(); i++) { 527 const auto indices = indices_list_in[i].matrix<int64>(); 528 int64 feature_count = 0; 529 int64 start_index = current_row[i]; 530 // Loops until we reach next batch index for current feature column. 531 while (current_row[i] < indices_list_in[i].dim_size(0) && 532 indices(current_row[i], 0) == b) { 533 feature_count++; 534 current_row[i]++; 535 } 536 (*feature_counts)[i].push_back(feature_count); 537 (*feature_start_indices)[i].push_back(start_index); 538 } 539 } 540 } 541 542 // Allocates output tensors with proper size and sets the shape tensor of 543 // the output SparseTensor. 544 // It also output_start_indices which contains the start indices for each 545 // input in the output SparseTensor. 546 void CreateOutputTensors( 547 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 548 columns, 549 int64 batch_size, OpKernelContext* context, Tensor** indices_out, 550 Tensor** values_out, Tensor** shape_out, 551 std::vector<int64>* output_start_indices) { 552 // Calculates dimensions for output tensors. 553 int64 cross_count_total = 0; 554 int64 max_cross_count = 0; 555 for (int64 b = 0; b < batch_size; b++) { 556 // For each input, sets starting indices in output SparseTensor 557 (*output_start_indices)[b] = cross_count_total; 558 const auto cross_count = CrossCountByBatchIndex(columns, b); 559 max_cross_count = std::max(max_cross_count, cross_count); 560 cross_count_total += cross_count; 561 } 562 563 // Allocates tensors. 564 OP_REQUIRES_OK(context, 565 context->allocate_output( 566 0, TensorShape({cross_count_total, 2}), indices_out)); 567 OP_REQUIRES_OK(context, 568 context->allocate_output(1, TensorShape({cross_count_total}), 569 values_out)); 570 OP_REQUIRES_OK(context, 571 context->allocate_output(2, TensorShape({2}), shape_out)); 572 573 // Sets shape. 574 auto shape_vec = (*shape_out)->vec<int64>(); 575 shape_vec(0) = batch_size; 576 shape_vec(1) = max_cross_count; 577 } 578 579 // Returns number of crosses for a given batch_index 580 int64 CrossCountByBatchIndex( 581 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 582 columns, 583 int batch_index) { 584 int64 cross_count = 1; 585 for (size_t i = 0; i < columns.size(); i++) { 586 const auto feature_count = columns[i]->FeatureCount(batch_index); 587 // If one column is missing any feature, there won't be any cross. 588 if (feature_count == 0) { 589 return 0; 590 } 591 cross_count *= feature_count; 592 } 593 return cross_count; 594 } 595 int64 num_buckets_; 596 uint64 hash_key_; 597}; 598 599REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") 600 .Device(DEVICE_CPU) 601 .TypeConstraint<string>("out_type") 602 .TypeConstraint<string>("internal_type"), 603 SparseFeatureCrossOp<false, StringPiece, false>); 604 605REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") 606 .Device(DEVICE_CPU) 607 .TypeConstraint<string>("out_type") 608 .TypeConstraint<int64>("internal_type"), 609 SparseFeatureCrossOp<false, string, false>); 610 611REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") 612 .Device(DEVICE_CPU) 613 .TypeConstraint<int64>("out_type") 614 .TypeConstraint<string>("internal_type"), 615 SparseFeatureCrossOp<true, int64, false>); 616 617REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") 618 .Device(DEVICE_CPU) 619 .TypeConstraint<int64>("out_type") 620 .TypeConstraint<int64>("internal_type"), 621 SparseFeatureCrossOp<true, int64, false>); 622 623// The following builders enable FingerprintCat64 concatenation for the 624// crosses features. 625REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") 626 .Device(DEVICE_CPU) 627 .TypeConstraint<string>("out_type") 628 .TypeConstraint<string>("internal_type"), 629 SparseFeatureCrossOp<false, StringPiece, true>); 630 631REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") 632 .Device(DEVICE_CPU) 633 .TypeConstraint<string>("out_type") 634 .TypeConstraint<int64>("internal_type"), 635 SparseFeatureCrossOp<false, string, true>); 636 637REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") 638 .Device(DEVICE_CPU) 639 .TypeConstraint<int64>("out_type") 640 .TypeConstraint<string>("internal_type"), 641 SparseFeatureCrossOp<true, int64, true>); 642 643REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") 644 .Device(DEVICE_CPU) 645 .TypeConstraint<int64>("out_type") 646 .TypeConstraint<int64>("internal_type"), 647 SparseFeatureCrossOp<true, int64, true>); 648 649} // namespace tensorflow 650