1fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
3fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFloweryou may not use this file except in compliance with the License.
5fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerYou may obtain a copy of the License at
6fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
7fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
9fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerSee the License for the specific language governing permissions and
13fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerlimitations under the License.
14fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower==============================================================================*/
15fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
16fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Contains OP to generate sparse crosses.
17fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <assert.h>
18fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <limits>
19fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <string>
20fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <vector>
21fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
22fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/kernel_def_builder.h"
24fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/op_def_builder.h"
25fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
26fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
27fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h"
28fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/types.h"
29fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/lib/core/stringpiece.h"
30e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/lib/strings/str_util.h"
31fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/platform/fingerprint.h"
32fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/util/work_sharder.h"
33fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
34fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowernamespace tensorflow {
35fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
36fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowernamespace {
37fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// An interface that represents a column with batches.
38fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
39fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass ColumnInterface {
40fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
41fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Returns the number of features in the specified batch.
42fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  virtual int64 FeatureCount(int64 batch) const = 0;
43fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
44fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Returns the fingerprint of nth feature from the specified batch.
45ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  virtual InternalType Feature(int64 batch, int64 n) const = 0;
46fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
47fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  virtual ~ColumnInterface() {}
48fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
49fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
50fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// A column that is backed by a sparse tensor.
51fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
52fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass SparseTensorColumn : public ColumnInterface<InternalType> {
53fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
54fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
55fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                     std::vector<int64> feature_start_indices)
56fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      : values_(values),
57fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        feature_counts_(std::move(feature_counts)),
58fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        feature_start_indices_(std::move(feature_start_indices)) {
59fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
60fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
61fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
62fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 FeatureCount(int64 batch) const override {
63fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return feature_counts_[batch];
64fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
65fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
66ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  InternalType Feature(int64 batch, int64 n) const override;
67fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
68fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  ~SparseTensorColumn() override {}
69fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
70fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
71fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const Tensor& values_;
72fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  std::vector<int64> feature_counts_;
73fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  std::vector<int64> feature_start_indices_;
74fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
75fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
76ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is int64 only when using HashCrosser.
77ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
78ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerint64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
79ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  const int64 start = feature_start_indices_[batch];
80ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  if (DT_STRING == values_.dtype())
81ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower    return Fingerprint64(values_.vec<string>().data()[start + n]);
82ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return values_.vec<int64>().data()[start + n];
83ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
84ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
85ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is string or StringPiece when using StringCrosser.
86ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
87ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerstring SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
88ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  const int64 start = feature_start_indices_[batch];
89ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  if (DT_STRING == values_.dtype())
90ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower    return values_.vec<string>().data()[start + n];
91ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return std::to_string(values_.vec<int64>().data()[start + n]);
92ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
93ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
94ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
95ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerStringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
96ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower                                                     int64 n) const {
97ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  const int64 start = feature_start_indices_[batch];
98ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return values_.vec<string>().data()[start + n];
99ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
100ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
101fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// A column that is backed by a dense tensor.
102fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
103fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass DenseTensorColumn : public ColumnInterface<InternalType> {
104fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
105fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
106fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
107fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
108fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
109ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  InternalType Feature(int64 batch, int64 n) const override;
110fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
111fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  ~DenseTensorColumn() override {}
112fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
113fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
114fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const Tensor& tensor_;
115fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
116fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
117ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is int64 only when using HashCrosser.
118ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
119ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerint64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
120ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  if (DT_STRING == tensor_.dtype())
121ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower    return Fingerprint64(tensor_.matrix<string>()(batch, n));
122ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return tensor_.matrix<int64>()(batch, n);
123ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
124ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
125ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// Internal type is string or StringPiece when using StringCrosser.
126ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
127ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerstring DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
128ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
129ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return std::to_string(tensor_.matrix<int64>()(batch, n));
130ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
131ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
132ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <>
133ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerStringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
134ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower                                                    int64 n) const {
135ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower  return tensor_.matrix<string>()(batch, n);
136ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower}
137ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower
138fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Updates Output tensors with sparse crosses.
139fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename OutType>
140fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass OutputUpdater {
141fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
142fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  OutputUpdater(const std::vector<int64>& output_start_indices,
143fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                Tensor* indices_out, Tensor* values_out)
144fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      : output_start_indices_(output_start_indices),
145fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        indices_out_(indices_out),
146fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        values_out_(values_out) {}
147fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
148fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  void Update(const int64 batch_index, const int64 cross_count,
149fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              const OutType& cross) const {
150fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const int64 output_index = output_start_indices_[batch_index] + cross_count;
151fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
152fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    auto indices_matrix = indices_out_->matrix<int64>();
153fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    indices_matrix(output_index, 0) = batch_index;
154fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    indices_matrix(output_index, 1) = cross_count;
155fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
156fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    auto value_vec = values_out_->vec<OutType>();
157fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    value_vec(output_index) = cross;
158fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
159fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
160fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
161fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const std::vector<int64>& output_start_indices_;
162fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  Tensor* indices_out_;
163fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  Tensor* values_out_;
164fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
165fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
166fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Generates the sparse crosses as concatenation of strings.
167fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
168fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass StringCrosser {
169fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
170fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  StringCrosser(const std::vector<
171fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                    std::unique_ptr<ColumnInterface<InternalType>>>& columns,
172fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                const int64 num_buckets_unused, const uint64 hash_key_unused)
173fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      : columns_(columns) {}
174fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
175fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  string Generate(const int64 batch_index,
176fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                  const std::vector<int>& permutation) const {
177fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    static const auto k_feature_separator = "_X_";
178fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
179fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
180fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < permutation.size(); i++) {
181fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
182fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
183fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // TODO(zakaria): this will copy the string twice, might effect
184fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // performance.
185fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return str_util::Join(cross_vec, k_feature_separator);
186fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
187fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
188fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
189fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
190fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
191fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
192fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Generates the sparse crosses as nested hash to avoid string manipulations.
193fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass HashCrosser {
194fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
195fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  HashCrosser(
196fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
197fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const int64 num_buckets, const uint64 hash_key)
198fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
199fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
200fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 Generate(const int64 batch_index,
201fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                 const std::vector<int>& permutation) const {
202fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Do the fingerprint concatenation on uint64.
203fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    uint64 hashed_output = hash_key_;
204fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (size_t i = 0; i < permutation.size(); ++i) {
205fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
206fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      hashed_output = FingerprintCat64(hashed_output, hash_i);
207fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
208fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // The return value is int64 based on the number of buckets.
209fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    if (num_buckets_ > 0) {
210fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      return hashed_output % num_buckets_;
211fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    } else {
212fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      // To prevent negative output we take modulo to max int64.
213fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      return hashed_output % std::numeric_limits<int64>::max();
214fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
215fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
216fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
217fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
218fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
219fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const int64 num_buckets_;
220fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const uint64 hash_key_;
221fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
222fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
223fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// ProductIterator generates cartesian products based on indices.
224fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
225fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass ProductIterator {
226fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
227fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  explicit ProductIterator(
228fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
229fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          columns,
230fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      int64 batch_index)
231fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      : columns_(columns), batch_index_(batch_index) {
232fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    next_permutation_.resize(columns_.size(), 0);
233fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Sets has_next_ to false if any feature column has 0 features.
234fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    has_next_ = true;
235fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < columns_.size(); i++) {
236fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      if (columns_[i]->FeatureCount(batch_index_) == 0) {
237fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        has_next_ = false;
238fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        break;
239fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
240fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
241fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
242fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
243fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  std::vector<int> Next() {
244fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<int> permutation(next_permutation_);
245fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
246fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Generates next permutation, if available.
247fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    bool carry = true;
248fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
249fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      if (carry) {
250fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        next_permutation_[i] = next_permutation_[i] + 1;
251fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
252fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
253fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        next_permutation_[i] = 0;
254fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      } else {
255fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        carry = false;
256fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        break;
257fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
258fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
259fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    has_next_ = !carry;
260fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return permutation;
261fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
262fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
263fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  bool HasNext() { return has_next_; }
264fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
265fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
266fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  bool has_next_;
267fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
268fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  const int64 batch_index_;
269fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  std::vector<int> next_permutation_;
270fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
271fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
272fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <bool HASHED_OUTPUT, typename InternalType>
273fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits;
274fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
275fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType>
276fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits<false, InternalType> {
277fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  typedef StringCrosser<InternalType> Crosser;
278fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  typedef OutputUpdater<string> Updater;
279fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
280fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
281fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <>
282fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits<true, int64> {
283fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  typedef HashCrosser Crosser;
284fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  typedef OutputUpdater<int64> Updater;
285fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
286fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}  // namespace
287fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
288fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <bool HASHED_OUTPUT, typename InternalType>
289fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass SparseCrossOp : public OpKernel {
290fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public:
291982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
292fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
293fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Read signed_hash_key_ as int64 since uint64 attributes are not
294fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // supported by REGISTER_OP.
295fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    int64 signed_hash_key_;
296fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
297fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    hash_key_ = static_cast<uint64>(signed_hash_key_);
298fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
299fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
300fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  void Compute(OpKernelContext* context) override {
301fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OpInputList indices_list_in;
302fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
303fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OpInputList values_list_in;
304fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
305fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OpInputList shapes_list_in;
306fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
307fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OpInputList dense_list_in;
308fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context,
309fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                   context->input_list("dense_inputs", &dense_list_in));
310fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
311fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
312fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                  dense_list_in);
313fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
314fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
315fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        GenerateColumnsFromInput(indices_list_in, values_list_in,
316fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                 shapes_list_in, dense_list_in);
317fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
318982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
319982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        columns, num_buckets_, hash_key_);
320fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    Tensor* indices_out;
321fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    Tensor* values_out;
322fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    Tensor* shape_out;
323fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
324fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<int64> output_start_indices(batch_size);
325fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
326fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                        &shape_out, &output_start_indices);
327fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
328982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
329982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        output_start_indices, indices_out, values_out);
330fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
331fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      for (int b = begin; b < end; b++) {
332fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        ProductIterator<InternalType> product_iterator(columns, b);
333fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        int64 cross_count = 0;
334fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        while (product_iterator.HasNext()) {
335fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          const auto permutation = product_iterator.Next();
336fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          updater.Update(b, cross_count, crosser.Generate(b, permutation));
337fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          cross_count++;
338fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        }
339fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
340fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    };
341fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
342fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
343fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // TODO(zakaria): optimize kCostPerUnit
344fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const int kCostPerUnit = 5000 * indices_list_in.size();
345fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
346fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          kCostPerUnit, do_work);
347fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
348fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
349fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private:
350fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Validates input tensors.
351fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  void ValidateInput(OpKernelContext* context,
352fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                     const OpInputList& indices_list_in,
353fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                     const OpInputList& values_list_in,
354fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                     const OpInputList& shapes_list_in,
355fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                     const OpInputList& dense_list_in) {
356fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const auto size = indices_list_in.size();
357fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Validates indices_list_in OpInputList.
358fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < size; i++) {
359fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
360fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
361fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument(
362fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              "Input indices should be a matrix but received shape ",
363fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              indices_list_in[i].shape().DebugString(), " at position ", i));
364fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
365fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, indices_list_in[i].shape().dim_size(1) == 2,
366fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument("Expected D2 of index to be 2 got ",
367fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                  indices_list_in[i].shape().dim_size(1),
368fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                  " at position ", i));
369fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
370fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
371fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Validates values_list_in OpInputList.
372fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES(
373fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        context, values_list_in.size() == size,
374fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        errors::InvalidArgument("Expected ", size, " input values, got ",
375fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                values_list_in.size()));
376fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < size; i++) {
377fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
378fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
379fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument(
380fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              "Input values should be a std::vector but received shape ",
381fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              values_list_in[i].shape().DebugString(), " at position ", i));
382fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
383982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          context,
384982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          indices_list_in[i].shape().dim_size(0) ==
385982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen              values_list_in[i].shape().dim_size(0),
386fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument(
387fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              "Expected size of values to be ",
388fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              indices_list_in[i].shape().dim_size(0), " got ",
389fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              values_list_in[i].shape().dim_size(0), " at position ", i));
390fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
391fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
392fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Validates shapes_list_in OpInputList
393fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES(
394fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        context, shapes_list_in.size() == size,
395fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        errors::InvalidArgument("Expected ", size, " input shapes, got ",
396fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                shapes_list_in.size()));
397fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
398fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < size; i++) {
399fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
400fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
401fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument(
402fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              "Input shapes should be a std::vector but received shape ",
403fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              shapes_list_in[i].shape().DebugString(), " at position ", i));
404fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
405fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
406fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, shapes_list_in[i].vec<int64>().size() == 2,
407fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument("shape should imply a 2D tensor, but got ",
408fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                  shapes_list_in[i].shape().DebugString(),
409fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                  " at position ", i));
410fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
411fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                  errors::InvalidArgument(
412fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                      "Expected batch size ", batch_size, " got ",
413fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                      shapes_list_in[i].vec<int64>()(0), " at position ", i));
414fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
415fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
416fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Validates dense_list_in OpInputList
417fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < dense_list_in.size(); ++i) {
418fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(
419fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
420fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          errors::InvalidArgument(
421fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              "Dense inputs should be a matrix but received shape ",
422fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower              indices_list_in[i].shape().DebugString(), " at position ", i));
423fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
424fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                  errors::InvalidArgument("Expected batch size ", batch_size,
425fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                          " got ", dense_list_in[i].dim_size(0),
426fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                          " at dense tensor ", i));
427fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
428fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
429fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
430fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Calculate the batch size from either the shapes input or the dense input.
431fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 CalculateBatchSize(const OpInputList& shapes_list_in,
432fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                           const OpInputList& dense_list_in) {
433fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    if (shapes_list_in.size() > 0) {
434fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      return shapes_list_in[0].vec<int64>()(0);
435fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
436fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
437fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    if (dense_list_in.size() > 0) {
438fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      return dense_list_in[0].dim_size(0);
439fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
440fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
441fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return 0;
442fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
443fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
444fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Generate the columns given the sparse and dense inputs.
445fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
446fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  GenerateColumnsFromInput(const OpInputList& indices_list_in,
447fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                           const OpInputList& values_list_in,
448fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                           const OpInputList& shapes_list_in,
449fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                           const OpInputList& dense_list_in) {
450fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
451fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
452fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    const int64 number_of_columns = shapes_list_in.size();
453fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
454fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<std::vector<int64>> feature_counts(number_of_columns,
455fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                                   std::vector<int64>());
456fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
457fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                                          std::vector<int64>());
458fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
459fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
460fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                       &feature_start_indices);
461fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
462eb10a4c494d95e7c17ddc44ef35197d08f2f6b33A. Unique TensorFlower    columns.reserve(values_list_in.size());
463fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < values_list_in.size(); ++i) {
464fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      columns.emplace_back(new SparseTensorColumn<InternalType>(
465fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          values_list_in[i], std::move(feature_counts[i]),
466fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          std::move(feature_start_indices[i])));
467fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
468fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < dense_list_in.size(); ++i) {
469fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      columns.emplace_back(
470fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          new DenseTensorColumn<InternalType>(dense_list_in[i]));
471fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
472fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
473fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return columns;
474fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
475fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
476fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Extracts data about the features and populates feature data.
477fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  void ExtractFeatureData(
478fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const OpInputList& indices_list_in, int64 batch_size,
479fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      std::vector<std::vector<int64>>* feature_counts,
480fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      std::vector<std::vector<int64>>* feature_start_indices) {
481fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
482fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int b = 0; b < batch_size; b++) {
483fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      for (int i = 0; i < indices_list_in.size(); i++) {
484fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        const auto indices = indices_list_in[i].matrix<int64>();
485fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        int64 feature_count = 0;
486fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        int64 start_index = current_row[i];
487fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        // Loops until we reach next batch index for current feature column.
488fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        while (current_row[i] < indices_list_in[i].dim_size(0) &&
489fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower               indices(current_row[i], 0) == b) {
490fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          feature_count++;
491fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          current_row[i]++;
492fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        }
493fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        (*feature_counts)[i].push_back(feature_count);
494fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        (*feature_start_indices)[i].push_back(start_index);
495fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
496fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
497fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
498fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
499fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Allocates output tensors with proper size and sets the shape tensor of
500fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // the output SparseTensor.
501fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // It also output_start_indices which contains the start indices for each
502fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // input in the output SparseTensor.
503fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  void CreateOutputTensors(
504fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
505fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          columns,
506fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      int64 batch_size, OpKernelContext* context, Tensor** indices_out,
507fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      Tensor** values_out, Tensor** shape_out,
508fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      std::vector<int64>* output_start_indices) {
509fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Calculates dimensions for output tensors.
510fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    int64 cross_count_total = 0;
511fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    int64 max_cross_count = 0;
512fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int64 b = 0; b < batch_size; b++) {
513fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      // For each input, sets starting indices in output SparseTensor
514fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      (*output_start_indices)[b] = cross_count_total;
515fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const auto cross_count = CrossCountByBatchIndex(columns, b);
516fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      max_cross_count = std::max(max_cross_count, cross_count);
517fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      cross_count_total += cross_count;
518fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
519fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
520fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Allocates tensors.
521fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context,
522fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                   context->allocate_output(
523fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                       0, TensorShape({cross_count_total, 2}), indices_out));
524fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context,
525fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                   context->allocate_output(1, TensorShape({cross_count_total}),
526fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                                            values_out));
527fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    OP_REQUIRES_OK(context,
528fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                   context->allocate_output(2, TensorShape({2}), shape_out));
529fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
530fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    // Sets shape.
531fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    auto shape_vec = (*shape_out)->vec<int64>();
532fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    shape_vec(0) = batch_size;
533fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    shape_vec(1) = max_cross_count;
534fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
535fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
536fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  // Returns number of crosses for a given batch_index
537fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 CrossCountByBatchIndex(
538fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
539fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower          columns,
540fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      int batch_index) {
541fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    int64 cross_count = 1;
542fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    for (int i = 0; i < columns.size(); i++) {
543fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      const auto feature_count = columns[i]->FeatureCount(batch_index);
544fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      // If one column is missing any feature, there won't be any cross.
545fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      if (feature_count == 0) {
546fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower        return 0;
547fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      }
548fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower      cross_count *= feature_count;
549fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    }
550fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower    return cross_count;
551fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  }
// Passed, together with hash_key_, to the Crosser constructed in Compute().
// NOTE(review): presumably the number of hash buckets for hashed output;
// confirm against the constructor and the Crosser implementations.
552fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  int64 num_buckets_;
// Passed to the Crosser constructed in Compute().
// NOTE(review): presumably the key/seed used when hashing crosses; confirm.
553fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower  uint64 hash_key_;
554fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower};
555fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
// CPU kernel registrations for the "SparseCross" op: one per combination of
// the `out_type` and `internal_type` attributes over {string, int64}.
// A string `out_type` selects SparseCrossOp<false, ...> and an int64
// `out_type` selects SparseCrossOp<true, ...>.

// out_type = string, internal_type = string.
556fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross")
557fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .Device(DEVICE_CPU)
558fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<string>("out_type")
559fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<string>("internal_type"),
560fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                        SparseCrossOp<false, StringPiece>);
561fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
// out_type = string, internal_type = int64.
// NOTE(review): the kernel is instantiated with `string`, not `int64`;
// presumably integer features are converted to strings by the column
// classes before crossing. Confirm against their definitions.
562fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross")
563fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .Device(DEVICE_CPU)
564fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<string>("out_type")
565fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<int64>("internal_type"),
566fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                        SparseCrossOp<false, string>);
567fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
// out_type = int64, internal_type = string.
// NOTE(review): the kernel is instantiated with `int64`, the same as the
// int64/int64 registration below; presumably string features are reduced to
// int64 by the column classes before hashing. Confirm that this is intended.
568fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross")
569fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .Device(DEVICE_CPU)
570fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<int64>("out_type")
571fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<string>("internal_type"),
572fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                        SparseCrossOp<true, int64>);
573fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
// out_type = int64, internal_type = int64.
574fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross")
575fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .Device(DEVICE_CPU)
576fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<int64>("out_type")
577fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                            .TypeConstraint<int64>("internal_type"),
578fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower                        SparseCrossOp<true, int64>);
579fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower
580fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}  // namespace tensorflow
581