/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up items from a sparse tensor in an embedding matrix.
// The sparse lookup tensor is represented by three individual tensors: lookup,
// indices, and dense_shape. The representation assumes that the corresponding
// dense tensor would satisfy:
//   * dense.shape = dense_shape
//   * dense[tuple(indices[i])] = lookup[i]
//
// By convention, indices should be sorted: the kernel aggregates consecutive
// entries that map to the same output row, so entries of one row that are not
// contiguous produce wrong MEAN/SQRTN results.
//
// Options:
//   combiner: The reduction op (SUM, MEAN, SQRTN).
//     * SUM computes the weighted sum of the embedding results.
//     * MEAN is the weighted sum divided by the total weight.
//     * SQRTN is the weighted sum divided by the square root of the sum of the
//       squares of the weights.
//
// Input:
//     Tensor[0]: Ids to lookup, dim.size == 1, int32.
//     Tensor[1]: Indices, dim.size == 2, int32.
//     Tensor[2]: Dense shape, dim.size == 1, int32.
//     Tensor[3]: Weights to use for aggregation, dim.size == 1, float.
//     Tensor[4]: Params, a matrix of multi-dimensional items,
//                dim.size >= 2, float.
//
// Output:
//   A (dense) tensor representing the combined embeddings for the sparse ids.
42// For each row in the sparse tensor represented by (lookup, indices, shape) 43// the op looks up the embeddings for all ids in that row, multiplies them by 44// the corresponding weight, and combines these embeddings as specified in the 45// last dimension. 46// 47// Output.dim = [l0, ... , ln-1, e1, ..., em] 48// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] 49// 50// For instance, if params is a 10x20 matrix and ids, weights are: 51// 52// [0, 0]: id 1, weight 2.0 53// [0, 1]: id 3, weight 0.5 54// [1, 0]: id 0, weight 1.0 55// [2, 3]: id 1, weight 3.0 56// 57// with combiner=MEAN, then the output will be a (3, 20) tensor where: 58// 59// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 60// output[1, :] = (params[0, :] * 1.0) / 1.0 61// output[2, :] = (params[1, :] * 3.0) / 3.0 62// 63// When indices are out of bound, the op will not succeed. 64 65#include <algorithm> 66#include <cmath> 67 68#include "tensorflow/contrib/lite/builtin_op_data.h" 69#include "tensorflow/contrib/lite/context.h" 70#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" 71#include "tensorflow/contrib/lite/kernels/kernel_util.h" 72#include "tensorflow/contrib/lite/kernels/op_macros.h" 73 74namespace tflite { 75namespace ops { 76namespace builtin { 77 78namespace { 79 80TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 81 TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); 82 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 83 84 TfLiteTensor* ids = GetInput(context, node, 0); 85 TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); 86 TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); 87 88 TfLiteTensor* indices = GetInput(context, node, 1); 89 TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); 90 TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); 91 92 TfLiteTensor* shape = GetInput(context, node, 2); 93 TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); 94 TF_LITE_ENSURE_EQ(context, shape->type, 
kTfLiteInt32); 95 96 TfLiteTensor* weights = GetInput(context, node, 3); 97 TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); 98 TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); 99 100 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), 101 SizeOfDimension(ids, 0)); 102 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), 103 SizeOfDimension(weights, 0)); 104 105 TfLiteTensor* value = GetInput(context, node, 4); 106 TF_LITE_ENSURE(context, NumDimensions(value) >= 2); 107 108 // Mark the output as a dynamic tensor. 109 TfLiteTensor* output = GetOutput(context, node, 0); 110 TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); 111 output->allocation_type = kTfLiteDynamic; 112 113 return kTfLiteOk; 114} 115 116void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements, 117 float current_total_weight, 118 float current_squares_weight, int embedding_size, 119 float* output) { 120 if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) { 121 float multiplier = 1.0; 122 switch (combiner) { 123 case kTfLiteCombinerTypeMean: 124 multiplier = current_total_weight; 125 break; 126 case kTfLiteCombinerTypeSqrtn: 127 multiplier = std::sqrt(current_squares_weight); 128 break; 129 default: 130 break; 131 } 132 for (int k = 0; k < embedding_size; k++) { 133 output[k] /= multiplier; 134 } 135 } 136} 137 138TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 139 auto* params = 140 reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data); 141 TfLiteTensor* output = GetOutput(context, node, 0); 142 TfLiteTensor* ids = GetInput(context, node, 0); 143 TfLiteTensor* indices = GetInput(context, node, 1); 144 TfLiteTensor* dense_shape = GetInput(context, node, 2); 145 TfLiteTensor* weights = GetInput(context, node, 3); 146 TfLiteTensor* value = GetInput(context, node, 4); 147 148 const int lookup_rank = SizeOfDimension(indices, 1); 149 const int embedding_rank = NumDimensions(value); 150 const int num_lookups = 
SizeOfDimension(ids, 0); 151 const int num_rows = SizeOfDimension(value, 0); 152 153 // The last dimension gets replaced by the embedding. 154 const int output_rank = (lookup_rank - 1) + (embedding_rank - 1); 155 156 // Make sure that the actual dense shape of the sparse tensor represented by 157 // (loopkup, indices, dense_shape) is consistent. 158 TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank); 159 160 // Resize output tensor. 161 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); 162 int k = 0; 163 int embedding_size = 1; 164 int lookup_size = 1; 165 for (int i = 0; i < lookup_rank - 1; i++, k++) { 166 const int dim = dense_shape->data.i32[i]; 167 lookup_size *= dim; 168 output_shape->data[k] = dim; 169 } 170 for (int i = 1; i < embedding_rank; i++, k++) { 171 const int dim = SizeOfDimension(value, i); 172 embedding_size *= dim; 173 output_shape->data[k] = dim; 174 } 175 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); 176 const int output_size = lookup_size * embedding_size; 177 TfLiteTensorRealloc(output_size * sizeof(float), output); 178 179 tensor_utils::ZeroVector(output->data.f, output_size); 180 181 // Keep track of the current bucket for aggregation/combination. 182 int current_output_offset = 0; 183 float current_total_weight = 0.0; 184 float current_squares_weight = 0.0; 185 int num_elements = 0; 186 187 for (int i = 0; i < num_lookups; i++) { 188 int idx = ids->data.i32[i]; 189 if (idx >= num_rows || idx < 0) { 190 context->ReportError(context, 191 "Embedding Lookup Sparse: index out of bounds."); 192 return kTfLiteError; 193 } 194 195 // Check where we need to aggregate. 
196 const int example_indices_offset = i * lookup_rank; 197 int output_bucket = 0; 198 int stride = 1; 199 for (int k = (lookup_rank - 1) - 1; k >= 0; k--) { 200 output_bucket += indices->data.i32[example_indices_offset + k] * stride; 201 stride *= dense_shape->data.i32[k]; 202 } 203 const int output_offset = output_bucket * embedding_size; 204 205 // If we are in a new aggregation bucket and the combiner is not the sum, 206 // go back and finalize the result of the previous bucket. 207 if (output_offset != current_output_offset) { 208 FinalizeAggregation(params->combiner, num_elements, current_total_weight, 209 current_squares_weight, embedding_size, 210 &output->data.f[current_output_offset]); 211 212 // Track next bucket. 213 num_elements = 0; 214 current_total_weight = 0.0; 215 current_squares_weight = 0.0; 216 current_output_offset = output_offset; 217 } 218 219 // Add element to aggregation. 220 ++num_elements; 221 const int example_embedding_offset = idx * embedding_size; 222 const float w = weights->data.f[i]; 223 current_squares_weight += w * w; 224 current_total_weight += w; 225 for (int k = 0; k < embedding_size; k++) { 226 output->data.f[current_output_offset + k] += 227 (value->data.f[example_embedding_offset + k] * w); 228 } 229 } 230 231 // Finalize last bucket. 232 FinalizeAggregation(params->combiner, num_elements, current_total_weight, 233 current_squares_weight, embedding_size, 234 &output->data.f[current_output_offset]); 235 236 return kTfLiteOk; 237} 238 239} // namespace 240 241TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() { 242 static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; 243 return &r; 244} 245 246} // namespace builtin 247} // namespace ops 248} // namespace tflite 249