1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Contains OP to generate sparse crosses.
17#include <assert.h>
18#include <limits>
19#include <string>
20#include <vector>
21
22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23#include "tensorflow/core/framework/kernel_def_builder.h"
24#include "tensorflow/core/framework/op_def_builder.h"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/lib/core/stringpiece.h"
30#include "tensorflow/core/lib/strings/str_util.h"
31#include "tensorflow/core/platform/fingerprint.h"
32#include "tensorflow/core/util/work_sharder.h"
33
34namespace tensorflow {
35
36namespace {
// An interface that represents a column with batches.
// InternalType is the representation features are converted to before
// crossing: int64 fingerprints for hashed crosses, string/StringPiece for
// string-concatenated crosses.
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64 FeatureCount(int64 batch) const = 0;

  // Returns the fingerprint of nth feature from the specified batch.
  // When InternalType is string/StringPiece this is the feature itself
  // rather than a fingerprint (see the specializations below).
  virtual InternalType Feature(int64 batch, int64 n) const = 0;

  // Virtual destructor so concrete columns are destroyed correctly
  // through ColumnInterface pointers.
  virtual ~ColumnInterface() {}
};
49
50// A column that is backed by a sparse tensor.
51template <typename InternalType>
52class SparseTensorColumn : public ColumnInterface<InternalType> {
53 public:
54  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
55                     std::vector<int64> feature_start_indices)
56      : values_(values),
57        feature_counts_(std::move(feature_counts)),
58        feature_start_indices_(std::move(feature_start_indices)) {
59    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
60  }
61
62  int64 FeatureCount(int64 batch) const override {
63    return feature_counts_[batch];
64  }
65
66  InternalType Feature(int64 batch, int64 n) const override;
67
68  ~SparseTensorColumn() override {}
69
70 private:
71  const Tensor& values_;
72  std::vector<int64> feature_counts_;
73  std::vector<int64> feature_start_indices_;
74};
75
76// InternalType is int64 only when using HashCrosser.
77template <>
78int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
79  const int64 start = feature_start_indices_[batch];
80  if (DT_STRING == values_.dtype())
81    return Fingerprint64(values_.vec<string>().data()[start + n]);
82  return values_.vec<int64>().data()[start + n];
83}
84
85// InternalType is string or StringPiece when using StringCrosser.
86template <>
87string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
88  const int64 start = feature_start_indices_[batch];
89  if (DT_STRING == values_.dtype())
90    return values_.vec<string>().data()[start + n];
91  return std::to_string(values_.vec<int64>().data()[start + n]);
92}
93
94template <>
95StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
96                                                     int64 n) const {
97  const int64 start = feature_start_indices_[batch];
98  return values_.vec<string>().data()[start + n];
99}
100
101// A column that is backed by a dense tensor.
102template <typename InternalType>
103class DenseTensorColumn : public ColumnInterface<InternalType> {
104 public:
105  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
106
107  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
108
109  InternalType Feature(int64 batch, int64 n) const override;
110
111  ~DenseTensorColumn() override {}
112
113 private:
114  const Tensor& tensor_;
115};
116
117// InternalType is int64 only when using HashCrosser.
118template <>
119int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
120  if (DT_STRING == tensor_.dtype())
121    return Fingerprint64(tensor_.matrix<string>()(batch, n));
122  return tensor_.matrix<int64>()(batch, n);
123}
124
125// Internal type is string or StringPiece when using StringCrosser.
126template <>
127string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
128  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
129  return std::to_string(tensor_.matrix<int64>()(batch, n));
130}
131
132template <>
133StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
134                                                    int64 n) const {
135  return tensor_.matrix<string>()(batch, n);
136}
137
138// Updates Output tensors with sparse crosses.
139template <typename OutType>
140class OutputUpdater {
141 public:
142  OutputUpdater(const std::vector<int64>& output_start_indices,
143                Tensor* indices_out, Tensor* values_out)
144      : output_start_indices_(output_start_indices),
145        indices_out_(indices_out),
146        values_out_(values_out) {}
147
148  void Update(const int64 batch_index, const int64 cross_count,
149              const OutType& cross) const {
150    const int64 output_index = output_start_indices_[batch_index] + cross_count;
151
152    auto indices_matrix = indices_out_->matrix<int64>();
153    indices_matrix(output_index, 0) = batch_index;
154    indices_matrix(output_index, 1) = cross_count;
155
156    auto value_vec = values_out_->vec<OutType>();
157    value_vec(output_index) = cross;
158  }
159
160 private:
161  const std::vector<int64>& output_start_indices_;
162  Tensor* indices_out_;
163  Tensor* values_out_;
164};
165
166// Generates the sparse crosses as concatenation of strings.
167template <typename InternalType>
168class StringCrosser {
169 public:
170  StringCrosser(const std::vector<
171                    std::unique_ptr<ColumnInterface<InternalType>>>& columns,
172                const int64 num_buckets_unused, const uint64 hash_key_unused)
173      : columns_(columns) {}
174
175  string Generate(const int64 batch_index,
176                  const std::vector<int>& permutation) const {
177    static const auto k_feature_separator = "_X_";
178
179    gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
180    for (size_t i = 0; i < permutation.size(); i++) {
181      cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
182    }
183    // TODO(zakaria): this will copy the string twice, might effect
184    // performance.
185    return str_util::Join(cross_vec, k_feature_separator);
186  }
187
188 private:
189  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
190};
191
192// Generates the sparse crosses as nested hash to avoid string manipulations.
193class HashCrosser {
194 public:
195  HashCrosser(
196      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
197      const int64 num_buckets, const uint64 hash_key_unused)
198      : columns_(columns), num_buckets_(num_buckets) {}
199
200  int64 Generate(const int64 batch_index,
201                 const std::vector<int>& permutation) const {
202    // Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
203    static const int64 kInitialHashSeed = 0xDECAFCAFFE;
204
205    uint64 hashed_output = kInitialHashSeed;
206    for (size_t i = 0; i < permutation.size(); ++i) {
207      int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
208      hashed_output = HashCombine(hashed_output, hash_i);
209    }
210    if (num_buckets_ > 0) {
211      return hashed_output % num_buckets_;
212    } else {
213      // To prevent negative output we take modulo to max int64.
214      return hashed_output % std::numeric_limits<int64>::max();
215    }
216  }
217
218 private:
219  static int64 HashCombine(int64 a, int64 b) {
220    return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4));
221  }
222
223  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
224  const int64 num_buckets_;
225};
226
227// Generates the sparse crosses as nested hash to avoid string manipulations.
228class HashCrosserV2 {
229 public:
230  HashCrosserV2(
231      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
232      const int64 num_buckets, const uint64 hash_key)
233      : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
234
235  int64 Generate(const int64 batch_index,
236                 const std::vector<int>& permutation) const {
237    // Do the fingerprint concatenation on uint64.
238    uint64 hashed_output = hash_key_;
239    for (size_t i = 0; i < permutation.size(); ++i) {
240      uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
241      hashed_output = FingerprintCat64(hashed_output, hash_i);
242    }
243    // The return value is int64 based on the number of buckets.
244    if (num_buckets_ > 0) {
245      return hashed_output % num_buckets_;
246    } else {
247      // To prevent negative output we take modulo to max int64.
248      return hashed_output % std::numeric_limits<int64>::max();
249    }
250  }
251
252 private:
253  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
254  const int64 num_buckets_;
255  const uint64 hash_key_;
256};
257
// ProductIterator generates cartesian products based on indices.
// It enumerates every tuple (p_0, ..., p_{k-1}) with
// 0 <= p_i < columns[i]->FeatureCount(batch_index), in odometer order
// (last column varies fastest).
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features:
    // an empty column makes the whole cartesian product empty.
    has_next_ = true;
    for (size_t i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation and advances the internal state to
  // the next one.  Only valid to call while HasNext() is true.
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    // Odometer increment: bump the last digit; a digit that wraps to 0
    // carries into the digit to its left.  If the carry propagates past
    // digit 0 the enumeration is exhausted.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  // True while at least one more permutation remains.
  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
306
307template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
308struct CrossTraits;
309
310template <typename InternalType, bool VERSION_2>
311struct CrossTraits<false, InternalType, VERSION_2> {
312  typedef StringCrosser<InternalType> Crosser;
313  typedef OutputUpdater<string> Updater;
314};
315
316template <>
317struct CrossTraits<true, int64, false> {
318  typedef HashCrosser Crosser;
319  typedef OutputUpdater<int64> Updater;
320};
321
322template <>
323struct CrossTraits<true, int64, true> {
324  typedef HashCrosserV2 Crosser;
325  typedef OutputUpdater<int64> Updater;
326};
327}  // namespace
328
329template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
330class SparseFeatureCrossOp : public OpKernel {
331 public:
332  explicit SparseFeatureCrossOp(OpKernelConstruction* context)
333      : OpKernel(context) {
334    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
335    if (VERSION_2) {
336      // Read signed_hash_key_ as int64 since uint64 attributes are not
337      // supported by REGISTER_OP.
338      int64 signed_hash_key_;
339      OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
340      hash_key_ = static_cast<uint64>(signed_hash_key_);
341    }
342  }
343
344  void Compute(OpKernelContext* context) override {
345    OpInputList indices_list_in;
346    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
347    OpInputList values_list_in;
348    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
349    OpInputList shapes_list_in;
350    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
351    OpInputList dense_list_in;
352    OP_REQUIRES_OK(context, context->input_list("dense", &dense_list_in));
353
354    ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
355                  dense_list_in);
356
357    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
358        GenerateColumnsFromInput(indices_list_in, values_list_in,
359                                 shapes_list_in, dense_list_in);
360
361    typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Crosser
362        crosser(columns, num_buckets_, hash_key_);
363    Tensor* indices_out;
364    Tensor* values_out;
365    Tensor* shape_out;
366    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
367    std::vector<int64> output_start_indices(batch_size);
368    CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
369                        &shape_out, &output_start_indices);
370
371    typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Updater
372        updater(output_start_indices, indices_out, values_out);
373    auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
374      for (int b = begin; b < end; b++) {
375        ProductIterator<InternalType> product_iterator(columns, b);
376        int64 cross_count = 0;
377        while (product_iterator.HasNext()) {
378          const auto permutation = product_iterator.Next();
379          updater.Update(b, cross_count, crosser.Generate(b, permutation));
380          cross_count++;
381        }
382      }
383    };
384
385    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
386    // TODO(zakaria): optimize kCostPerUnit
387    const int kCostPerUnit = 5000 * indices_list_in.size();
388    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
389          kCostPerUnit, do_work);
390  }
391
392 private:
393  // Validates input tensors.
394  void ValidateInput(OpKernelContext* context,
395                     const OpInputList& indices_list_in,
396                     const OpInputList& values_list_in,
397                     const OpInputList& shapes_list_in,
398                     const OpInputList& dense_list_in) {
399    const auto size = indices_list_in.size();
400    // Validates indices_list_in OpInputList.
401    for (int i = 0; i < size; i++) {
402      OP_REQUIRES(
403          context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
404          errors::InvalidArgument(
405              "Input indices should be a matrix but received shape ",
406              indices_list_in[i].shape().DebugString(), " at position ", i));
407      OP_REQUIRES(
408          context, indices_list_in[i].shape().dim_size(1) == 2,
409          errors::InvalidArgument("Expected D2 of index to be 2 got ",
410                                  indices_list_in[i].shape().dim_size(1),
411                                  " at position ", i));
412    }
413
414    // Validates values_list_in OpInputList.
415    OP_REQUIRES(
416        context, values_list_in.size() == size,
417        errors::InvalidArgument("Expected ", size, " input values, got ",
418                                values_list_in.size()));
419    for (int i = 0; i < size; i++) {
420      OP_REQUIRES(
421          context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
422          errors::InvalidArgument(
423              "Input values should be a std::vector but received shape ",
424              values_list_in[i].shape().DebugString(), " at position ", i));
425      OP_REQUIRES(
426          context,
427          indices_list_in[i].shape().dim_size(0) ==
428              values_list_in[i].shape().dim_size(0),
429          errors::InvalidArgument(
430              "Expected size of values to be ",
431              indices_list_in[i].shape().dim_size(0), " got ",
432              values_list_in[i].shape().dim_size(0), " at position ", i));
433    }
434
435    // Validates shapes_list_in OpInputList
436    OP_REQUIRES(
437        context, shapes_list_in.size() == size,
438        errors::InvalidArgument("Expected ", size, " input shapes, got ",
439                                shapes_list_in.size()));
440    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
441    for (int i = 0; i < size; i++) {
442      OP_REQUIRES(
443          context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
444          errors::InvalidArgument(
445              "Input shapes should be a std::vector but received shape ",
446              shapes_list_in[i].shape().DebugString(), " at position ", i));
447
448      OP_REQUIRES(
449          context, shapes_list_in[i].vec<int64>().size() == 2,
450          errors::InvalidArgument("shape should imply a 2D tensor, but got ",
451                                  shapes_list_in[i].shape().DebugString(),
452                                  " at position ", i));
453      OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
454                  errors::InvalidArgument(
455                      "Expected batch size ", batch_size, " got ",
456                      shapes_list_in[i].vec<int64>()(0), " at position ", i));
457    }
458
459    // Validates dense_list_in OpInputList
460    for (int i = 0; i < dense_list_in.size(); ++i) {
461      OP_REQUIRES(
462          context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
463          errors::InvalidArgument(
464              "Dense inputs should be a matrix but received shape ",
465              indices_list_in[i].shape().DebugString(), " at position ", i));
466      OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
467                  errors::InvalidArgument("Expected batch size ", batch_size,
468                                          " got ", dense_list_in[i].dim_size(0),
469                                          " at dense tensor ", i));
470    }
471  }
472
473  // Calculate the batch size from either the shapes input or the dense input.
474  int64 CalculateBatchSize(const OpInputList& shapes_list_in,
475                           const OpInputList& dense_list_in) {
476    if (shapes_list_in.size() > 0) {
477      return shapes_list_in[0].vec<int64>()(0);
478    }
479
480    if (dense_list_in.size() > 0) {
481      return dense_list_in[0].dim_size(0);
482    }
483
484    return 0;
485  }
486
487  // Generate the columns given the sparse and dense inputs.
488  std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
489  GenerateColumnsFromInput(const OpInputList& indices_list_in,
490                           const OpInputList& values_list_in,
491                           const OpInputList& shapes_list_in,
492                           const OpInputList& dense_list_in) {
493    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
494    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
495    const int64 number_of_columns = shapes_list_in.size();
496
497    std::vector<std::vector<int64>> feature_counts(number_of_columns,
498                                                   std::vector<int64>());
499    std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
500                                                          std::vector<int64>());
501
502    ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
503                       &feature_start_indices);
504
505    columns.reserve(values_list_in.size());
506    for (int i = 0; i < values_list_in.size(); ++i) {
507      columns.emplace_back(new SparseTensorColumn<InternalType>(
508          values_list_in[i], std::move(feature_counts[i]),
509          std::move(feature_start_indices[i])));
510    }
511    for (int i = 0; i < dense_list_in.size(); ++i) {
512      columns.emplace_back(
513          new DenseTensorColumn<InternalType>(dense_list_in[i]));
514    }
515
516    return columns;
517  }
518
519  // Extracts data about the features and populates feature data.
520  void ExtractFeatureData(
521      const OpInputList& indices_list_in, int64 batch_size,
522      std::vector<std::vector<int64>>* feature_counts,
523      std::vector<std::vector<int64>>* feature_start_indices) {
524    gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
525    for (int b = 0; b < batch_size; b++) {
526      for (int i = 0; i < indices_list_in.size(); i++) {
527        const auto indices = indices_list_in[i].matrix<int64>();
528        int64 feature_count = 0;
529        int64 start_index = current_row[i];
530        // Loops until we reach next batch index for current feature column.
531        while (current_row[i] < indices_list_in[i].dim_size(0) &&
532               indices(current_row[i], 0) == b) {
533          feature_count++;
534          current_row[i]++;
535        }
536        (*feature_counts)[i].push_back(feature_count);
537        (*feature_start_indices)[i].push_back(start_index);
538      }
539    }
540  }
541
542  // Allocates output tensors with proper size and sets the shape tensor of
543  // the output SparseTensor.
544  // It also output_start_indices which contains the start indices for each
545  // input in the output SparseTensor.
546  void CreateOutputTensors(
547      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
548          columns,
549      int64 batch_size, OpKernelContext* context, Tensor** indices_out,
550      Tensor** values_out, Tensor** shape_out,
551      std::vector<int64>* output_start_indices) {
552    // Calculates dimensions for output tensors.
553    int64 cross_count_total = 0;
554    int64 max_cross_count = 0;
555    for (int64 b = 0; b < batch_size; b++) {
556      // For each input, sets starting indices in output SparseTensor
557      (*output_start_indices)[b] = cross_count_total;
558      const auto cross_count = CrossCountByBatchIndex(columns, b);
559      max_cross_count = std::max(max_cross_count, cross_count);
560      cross_count_total += cross_count;
561    }
562
563    // Allocates tensors.
564    OP_REQUIRES_OK(context,
565                   context->allocate_output(
566                       0, TensorShape({cross_count_total, 2}), indices_out));
567    OP_REQUIRES_OK(context,
568                   context->allocate_output(1, TensorShape({cross_count_total}),
569                                            values_out));
570    OP_REQUIRES_OK(context,
571                   context->allocate_output(2, TensorShape({2}), shape_out));
572
573    // Sets shape.
574    auto shape_vec = (*shape_out)->vec<int64>();
575    shape_vec(0) = batch_size;
576    shape_vec(1) = max_cross_count;
577  }
578
579  // Returns number of crosses for a given batch_index
580  int64 CrossCountByBatchIndex(
581      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
582          columns,
583      int batch_index) {
584    int64 cross_count = 1;
585    for (size_t i = 0; i < columns.size(); i++) {
586      const auto feature_count = columns[i]->FeatureCount(batch_index);
587      // If one column is missing any feature, there won't be any cross.
588      if (feature_count == 0) {
589        return 0;
590      }
591      cross_count *= feature_count;
592    }
593    return cross_count;
594  }
595  int64 num_buckets_;
596  uint64 hash_key_;
597};
598
// V1 registrations (HashCombine-based hashing when out_type is int64).
// For string output, internal_type selects StringPiece (zero-copy views,
// valid since inputs are strings) vs string (int64 inputs converted to
// decimal strings).
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<false, StringPiece, false>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<false, string, false>);

// For hashed (int64) output the kernel always uses the int64 internal
// representation regardless of the declared internal_type, so both
// constraints below map to the same instantiation.
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<true, int64, false>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<true, int64, false>);

// The following builders enable FingerprintCat64 concatenation for the
// crosses features.
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<false, StringPiece, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<false, string, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<true, int64, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<true, int64, true>);
648
649}  // namespace tensorflow
650