1fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 3fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFloweryou may not use this file except in compliance with the License. 5fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerYou may obtain a copy of the License at 6fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 7fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 9fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerSee the License for the specific language governing permissions and 13fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerlimitations under the License. 14fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower==============================================================================*/ 15fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 16fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Contains OP to generate sparse crosses. 17fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <assert.h> 18fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <limits> 19fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <string> 20fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include <vector> 21fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 22fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/kernel_def_builder.h" 24fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/op_def_builder.h" 25fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h" 26fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 27fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h" 28fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/framework/types.h" 29fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/lib/core/stringpiece.h" 30e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/lib/strings/str_util.h" 31fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/platform/fingerprint.h" 32fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower#include "tensorflow/core/util/work_sharder.h" 33fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 34fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowernamespace tensorflow { 35fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 36fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowernamespace { 37fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// An interface that represents a column with batches. 38fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 39fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass ColumnInterface { 40fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 41fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Returns the number of features in the specified batch. 42fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower virtual int64 FeatureCount(int64 batch) const = 0; 43fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 44fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Returns the fingerprint of nth feature from the specified batch. 45ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower virtual InternalType Feature(int64 batch, int64 n) const = 0; 46fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 47fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower virtual ~ColumnInterface() {} 48fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 49fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 50fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// A column that is backed by a sparse tensor. 51fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 52fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass SparseTensorColumn : public ColumnInterface<InternalType> { 53fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 54fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts, 55fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64> feature_start_indices) 56fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower : values_(values), 57fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower feature_counts_(std::move(feature_counts)), 58fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower feature_start_indices_(std::move(feature_start_indices)) { 59fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower CHECK_EQ(feature_counts_.size(), feature_start_indices_.size()); 60fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 61fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 62fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 FeatureCount(int64 batch) const override { 63fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return feature_counts_[batch]; 64fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 65fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 66ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower InternalType Feature(int64 batch, int64 n) const override; 67fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 68fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower ~SparseTensorColumn() override {} 69fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 70fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 71fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const Tensor& values_; 72fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64> feature_counts_; 73fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64> feature_start_indices_; 74fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 75fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 76ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is int64 only when using HashCrosser. 77ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 78ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerint64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const { 79ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower const int64 start = feature_start_indices_[batch]; 80ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower if (DT_STRING == values_.dtype()) 81ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return Fingerprint64(values_.vec<string>().data()[start + n]); 82ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return values_.vec<int64>().data()[start + n]; 83ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 84ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 85ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is string or StringPiece when using StringCrosser. 86ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 87ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerstring SparseTensorColumn<string>::Feature(int64 batch, int64 n) const { 88ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower const int64 start = feature_start_indices_[batch]; 89ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower if (DT_STRING == values_.dtype()) 90ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return values_.vec<string>().data()[start + n]; 91ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return std::to_string(values_.vec<int64>().data()[start + n]); 92ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 93ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 94ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 95ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerStringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, 96ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower int64 n) const { 97ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower const int64 start = feature_start_indices_[batch]; 98ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return values_.vec<string>().data()[start + n]; 99ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 100ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 101fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// A column that is backed by a dense tensor. 102fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 103fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass DenseTensorColumn : public ColumnInterface<InternalType> { 104fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 105fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {} 106fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 107fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } 108fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 109ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower InternalType Feature(int64 batch, int64 n) const override; 110fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 111fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower ~DenseTensorColumn() override {} 112fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 113fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 114fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const Tensor& tensor_; 115fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 116fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 117ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// InternalType is int64 only when using HashCrosser. 118ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 119ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerint64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const { 120ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower if (DT_STRING == tensor_.dtype()) 121ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return Fingerprint64(tensor_.matrix<string>()(batch, n)); 122ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return tensor_.matrix<int64>()(batch, n); 123ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 124ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 125ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower// Internal type is string or StringPiece when using StringCrosser. 126ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 127ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerstring DenseTensorColumn<string>::Feature(int64 batch, int64 n) const { 128ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n); 129ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return std::to_string(tensor_.matrix<int64>()(batch, n)); 130ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 131ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 132ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowertemplate <> 133ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlowerStringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, 134ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower int64 n) const { 135ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower return tensor_.matrix<string>()(batch, n); 136ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower} 137ba656b261141554f33b96c655e3a0c76eb0d837dA. Unique TensorFlower 138fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Updates Output tensors with sparse crosses. 139fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename OutType> 140fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass OutputUpdater { 141fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 142fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OutputUpdater(const std::vector<int64>& output_start_indices, 143fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* indices_out, Tensor* values_out) 144fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower : output_start_indices_(output_start_indices), 145fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_out_(indices_out), 146fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_out_(values_out) {} 147fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 148fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower void Update(const int64 batch_index, const int64 cross_count, 149fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OutType& cross) const { 150fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 output_index = output_start_indices_[batch_index] + cross_count; 151fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 152fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower auto indices_matrix = indices_out_->matrix<int64>(); 153fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_matrix(output_index, 0) = batch_index; 154fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_matrix(output_index, 1) = cross_count; 155fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 156fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower auto value_vec = values_out_->vec<OutType>(); 157fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower value_vec(output_index) = cross; 158fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 159fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 160fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 161fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<int64>& output_start_indices_; 162fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* indices_out_; 163fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* values_out_; 164fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 165fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 166fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Generates the sparse crosses as concatenation of strings. 167fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 168fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass StringCrosser { 169fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 170fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower StringCrosser(const std::vector< 171fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::unique_ptr<ColumnInterface<InternalType>>>& columns, 172fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 num_buckets_unused, const uint64 hash_key_unused) 173fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower : columns_(columns) {} 174fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 175fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower string Generate(const int64 batch_index, 176fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<int>& permutation) const { 177fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower static const auto k_feature_separator = "_X_"; 178fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 179fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size()); 180fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < permutation.size(); i++) { 181fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]); 182fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 183fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // TODO(zakaria): this will copy the string twice, might effect 184fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // performance. 185fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return str_util::Join(cross_vec, k_feature_separator); 186fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 187fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 188fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 189fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_; 190fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 191fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 192fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// Generates the sparse crosses as nested hash to avoid string manipulations. 193fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass HashCrosser { 194fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 195fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower HashCrosser( 196fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns, 197fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 num_buckets, const uint64 hash_key) 198fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {} 199fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 200fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 Generate(const int64 batch_index, 201fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<int>& permutation) const { 202fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Do the fingerprint concatenation on uint64. 203fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower uint64 hashed_output = hash_key_; 204fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (size_t i = 0; i < permutation.size(); ++i) { 205fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]); 206fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower hashed_output = FingerprintCat64(hashed_output, hash_i); 207fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 208fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // The return value is int64 based on the number of buckets. 209fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (num_buckets_ > 0) { 210fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return hashed_output % num_buckets_; 211fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } else { 212fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // To prevent negative output we take modulo to max int64. 213fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return hashed_output % std::numeric_limits<int64>::max(); 214fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 215fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 216fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 217fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 218fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_; 219fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 num_buckets_; 220fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const uint64 hash_key_; 221fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 222fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 223fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower// ProductIterator generates cartesian products based on indices. 224fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 225fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass ProductIterator { 226fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 227fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower explicit ProductIterator( 228fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 229fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower columns, 230fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 batch_index) 231fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower : columns_(columns), batch_index_(batch_index) { 232fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower next_permutation_.resize(columns_.size(), 0); 233fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Sets has_next_ to false if any feature column has 0 features. 234fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower has_next_ = true; 235fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < columns_.size(); i++) { 236fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (columns_[i]->FeatureCount(batch_index_) == 0) { 237fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower has_next_ = false; 238fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower break; 239fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 240fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 241fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 242fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 243fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int> Next() { 244fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int> permutation(next_permutation_); 245fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 246fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Generates next permutation, if available. 247fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower bool carry = true; 248fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = next_permutation_.size() - 1; i >= 0; i--) { 249fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (carry) { 250fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower next_permutation_[i] = next_permutation_[i] + 1; 251fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 252fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) { 253fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower next_permutation_[i] = 0; 254fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } else { 255fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower carry = false; 256fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower break; 257fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 258fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 259fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower has_next_ = !carry; 260fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return permutation; 261fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 262fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 263fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower bool HasNext() { return has_next_; } 264fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 265fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 266fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower bool has_next_; 267fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_; 268fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 batch_index_; 269fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int> next_permutation_; 270fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 271fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 272fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <bool HASHED_OUTPUT, typename InternalType> 273fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits; 274fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 275fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <typename InternalType> 276fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits<false, InternalType> { 277fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower typedef StringCrosser<InternalType> Crosser; 278fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower typedef OutputUpdater<string> Updater; 279fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 280fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 281fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <> 282fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerstruct CrossTraits<true, int64> { 283fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower typedef HashCrosser Crosser; 284fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower typedef OutputUpdater<int64> Updater; 285fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 286fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower} // namespace 287fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 288fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowertemplate <bool HASHED_OUTPUT, typename InternalType> 289fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerclass SparseCrossOp : public OpKernel { 290fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower public: 291982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) { 292fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_)); 293fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Read signed_hash_key_ as int64 since uint64 attributes are not 294fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // supported by REGISTER_OP. 295fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 signed_hash_key_; 296fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_)); 297fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower hash_key_ = static_cast<uint64>(signed_hash_key_); 298fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 299fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 300fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower void Compute(OpKernelContext* context) override { 301fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OpInputList indices_list_in; 302fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in)); 303fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OpInputList values_list_in; 304fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, context->input_list("values", &values_list_in)); 305fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OpInputList shapes_list_in; 306fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in)); 307fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OpInputList dense_list_in; 308fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, 309fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context->input_list("dense_inputs", &dense_list_in)); 310fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 311fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower ValidateInput(context, indices_list_in, values_list_in, shapes_list_in, 312fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower dense_list_in); 313fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 314fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns = 315fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower GenerateColumnsFromInput(indices_list_in, values_list_in, 316fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shapes_list_in, dense_list_in); 317fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 318982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser( 319982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen columns, num_buckets_, hash_key_); 320fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* indices_out; 321fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* values_out; 322fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor* shape_out; 323fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 324fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64> output_start_indices(batch_size); 325fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out, 326fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower &shape_out, &output_start_indices); 327fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 328982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater( 329982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen output_start_indices, indices_out, values_out); 330fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) { 331fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int b = begin; b < end; b++) { 332fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower ProductIterator<InternalType> product_iterator(columns, b); 333fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 cross_count = 0; 334fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower while (product_iterator.HasNext()) { 335fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto permutation = product_iterator.Next(); 336fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower updater.Update(b, cross_count, crosser.Generate(b, permutation)); 337fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower cross_count++; 338fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 339fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 340fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower }; 341fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 342fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower auto* worker_threads = context->device()->tensorflow_cpu_worker_threads(); 343fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // TODO(zakaria): optimize kCostPerUnit 344fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int kCostPerUnit = 5000 * indices_list_in.size(); 345fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Shard(worker_threads->num_threads, worker_threads->workers, batch_size, 346fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower kCostPerUnit, do_work); 347fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 348fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 349fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower private: 350fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Validates input tensors. 351fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower void ValidateInput(OpKernelContext* context, 352fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& indices_list_in, 353fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& values_list_in, 354fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& shapes_list_in, 355fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& dense_list_in) { 356fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto size = indices_list_in.size(); 357fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Validates indices_list_in OpInputList. 358fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < size; i++) { 359fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 360fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()), 361fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 362fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Input indices should be a matrix but received shape ", 363fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_list_in[i].shape().DebugString(), " at position ", i)); 364fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 365fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, indices_list_in[i].shape().dim_size(1) == 2, 366fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument("Expected D2 of index to be 2 got ", 367fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_list_in[i].shape().dim_size(1), 368fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower " at position ", i)); 369fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 370fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 371fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Validates values_list_in OpInputList. 372fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 373fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, values_list_in.size() == size, 374fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument("Expected ", size, " input values, got ", 375fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_list_in.size())); 376fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < size; i++) { 377fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 378fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, TensorShapeUtils::IsVector(values_list_in[i].shape()), 379fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 380fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Input values should be a std::vector but received shape ", 381fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_list_in[i].shape().DebugString(), " at position ", i)); 382fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 383982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, 384982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen indices_list_in[i].shape().dim_size(0) == 385982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen values_list_in[i].shape().dim_size(0), 386fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 387fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Expected size of values to be ", 388fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_list_in[i].shape().dim_size(0), " got ", 389fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_list_in[i].shape().dim_size(0), " at position ", i)); 390fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 391fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 392fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Validates shapes_list_in OpInputList 393fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 394fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, shapes_list_in.size() == size, 395fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument("Expected ", size, " input shapes, got ", 396fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shapes_list_in.size())); 397fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 398fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < size; i++) { 399fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 400fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()), 401fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 402fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Input shapes should be a std::vector but received shape ", 403fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shapes_list_in[i].shape().DebugString(), " at position ", i)); 404fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 405fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 406fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, shapes_list_in[i].vec<int64>().size() == 2, 407fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument("shape should imply a 2D tensor, but got ", 408fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shapes_list_in[i].shape().DebugString(), 409fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower " at position ", i)); 410fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size, 411fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 412fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Expected batch size ", batch_size, " got ", 413fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shapes_list_in[i].vec<int64>()(0), " at position ", i)); 414fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 415fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 416fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Validates dense_list_in OpInputList 417fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < dense_list_in.size(); ++i) { 418fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES( 419fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()), 420fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument( 421fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower "Dense inputs should be a matrix but received shape ", 422fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices_list_in[i].shape().DebugString(), " at position ", i)); 423fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size, 424fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower errors::InvalidArgument("Expected batch size ", batch_size, 425fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower " got ", dense_list_in[i].dim_size(0), 426fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower " at dense tensor ", i)); 427fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 428fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 429fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 430fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Calculate the batch size from either the shapes input or the dense input. 431fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 CalculateBatchSize(const OpInputList& shapes_list_in, 432fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& dense_list_in) { 433fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (shapes_list_in.size() > 0) { 434fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return shapes_list_in[0].vec<int64>()(0); 435fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 436fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 437fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (dense_list_in.size() > 0) { 438fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return dense_list_in[0].dim_size(0); 439fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 440fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 441fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return 0; 442fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 443fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 444fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Generate the columns given the sparse and dense inputs. 445fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::unique_ptr<ColumnInterface<InternalType>>> 446fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower GenerateColumnsFromInput(const OpInputList& indices_list_in, 447fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& values_list_in, 448fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& shapes_list_in, 449fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& dense_list_in) { 450fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns; 451fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); 452fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const int64 number_of_columns = shapes_list_in.size(); 453fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 454fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::vector<int64>> feature_counts(number_of_columns, 455fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64>()); 456fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::vector<int64>> feature_start_indices(number_of_columns, 457fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64>()); 458fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 459fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower ExtractFeatureData(indices_list_in, batch_size, &feature_counts, 460fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower &feature_start_indices); 461fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 462eb10a4c494d95e7c17ddc44ef35197d08f2f6b33A. Unique TensorFlower columns.reserve(values_list_in.size()); 463fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < values_list_in.size(); ++i) { 464fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower columns.emplace_back(new SparseTensorColumn<InternalType>( 465fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_list_in[i], std::move(feature_counts[i]), 466fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::move(feature_start_indices[i]))); 467fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 468fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < dense_list_in.size(); ++i) { 469fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower columns.emplace_back( 470fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower new DenseTensorColumn<InternalType>(dense_list_in[i])); 471fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 472fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 473fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return columns; 474fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 475fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 476fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Extracts data about the features and populates feature data. 477fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower void ExtractFeatureData( 478fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const OpInputList& indices_list_in, int64 batch_size, 479fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::vector<int64>>* feature_counts, 480fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<std::vector<int64>>* feature_start_indices) { 481fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0); 482fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int b = 0; b < batch_size; b++) { 483fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < indices_list_in.size(); i++) { 484fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto indices = indices_list_in[i].matrix<int64>(); 485fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 feature_count = 0; 486fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 start_index = current_row[i]; 487fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Loops until we reach next batch index for current feature column. 488fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower while (current_row[i] < indices_list_in[i].dim_size(0) && 489fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower indices(current_row[i], 0) == b) { 490fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower feature_count++; 491fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower current_row[i]++; 492fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 493fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower (*feature_counts)[i].push_back(feature_count); 494fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower (*feature_start_indices)[i].push_back(start_index); 495fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 496fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 497fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 498fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 499fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Allocates output tensors with proper size and sets the shape tensor of 500fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // the output SparseTensor. 501fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // It also output_start_indices which contains the start indices for each 502fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // input in the output SparseTensor. 503fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower void CreateOutputTensors( 504fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 505fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower columns, 506fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 batch_size, OpKernelContext* context, Tensor** indices_out, 507fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower Tensor** values_out, Tensor** shape_out, 508fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower std::vector<int64>* output_start_indices) { 509fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Calculates dimensions for output tensors. 510fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 cross_count_total = 0; 511fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 max_cross_count = 0; 512fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int64 b = 0; b < batch_size; b++) { 513fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // For each input, sets starting indices in output SparseTensor 514fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower (*output_start_indices)[b] = cross_count_total; 515fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto cross_count = CrossCountByBatchIndex(columns, b); 516fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower max_cross_count = std::max(max_cross_count, cross_count); 517fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower cross_count_total += cross_count; 518fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 519fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 520fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Allocates tensors. 521fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, 522fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context->allocate_output( 523fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 0, TensorShape({cross_count_total, 2}), indices_out)); 524fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, 525fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context->allocate_output(1, TensorShape({cross_count_total}), 526fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower values_out)); 527fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower OP_REQUIRES_OK(context, 528fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower context->allocate_output(2, TensorShape({2}), shape_out)); 529fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 530fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Sets shape. 531fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower auto shape_vec = (*shape_out)->vec<int64>(); 532fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shape_vec(0) = batch_size; 533fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower shape_vec(1) = max_cross_count; 534fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 535fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 536fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // Returns number of crosses for a given batch_index 537fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 CrossCountByBatchIndex( 538fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& 539fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower columns, 540fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int batch_index) { 541fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 cross_count = 1; 542fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower for (int i = 0; i < columns.size(); i++) { 543fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower const auto feature_count = columns[i]->FeatureCount(batch_index); 544fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower // If one column is missing any feature, there won't be any cross. 545fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower if (feature_count == 0) { 546fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return 0; 547fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 548fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower cross_count *= feature_count; 549fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 550fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower return cross_count; 551fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower } 552fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower int64 num_buckets_; 553fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower uint64 hash_key_; 554fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower}; 555fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 556fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross") 557fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .Device(DEVICE_CPU) 558fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<string>("out_type") 559fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<string>("internal_type"), 560fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower SparseCrossOp<false, StringPiece>); 561fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 562fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross") 563fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .Device(DEVICE_CPU) 564fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<string>("out_type") 565fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<int64>("internal_type"), 566fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower SparseCrossOp<false, string>); 567fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 568fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross") 569fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .Device(DEVICE_CPU) 570fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<int64>("out_type") 571fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<string>("internal_type"), 572fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower SparseCrossOp<true, int64>); 573fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 574fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SparseCross") 575fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .Device(DEVICE_CPU) 576fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<int64>("out_type") 577fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower .TypeConstraint<int64>("internal_type"), 578fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower SparseCrossOp<true, int64>); 579fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower 580fd85d42d5d8203b4ce17c2825fd003411c78fbe6A. Unique TensorFlower} // namespace tensorflow 581