1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Op that looks up items from a sparse tensor in an embedding matrix.
17// The sparse lookup tensor is represented by three individual tensors: lookup,
// indices, and dense_shape. The representation assumes that the corresponding
19// dense tensor would satisfy:
20//   * dense.shape = dense_shape
21//   * dense[tuple(indices[i])] = lookup[i]
22//
23// By convention, indices should be sorted.
24//
25// Options:
26//   combiner: The reduction op (SUM, MEAN, SQRTN).
27//     * SUM computes the weighted sum of the embedding results.
28//     * MEAN is the weighted sum divided by the total weight.
29//     * SQRTN is the weighted sum divided by the square root of the sum of the
30//       squares of the weights.
31//
32// Input:
//     Tensor[0]: Ids to look up, dim.size == 1, int32.
34//     Tensor[1]: Indices, int32.
35//     Tensor[2]: Dense shape, int32.
36//     Tensor[3]: Weights to use for aggregation, float.
37//     Tensor[4]: Params, a matrix of multi-dimensional items,
38//                dim.size >= 2, float.
39//
40// Output:
41//   A (dense) tensor representing the combined embeddings for the sparse ids.
//   For each row in the sparse tensor represented by (lookup, indices, shape)
//   the op looks up the embeddings for all ids in that row, multiplies them by
//   the corresponding weight, and combines these embeddings as specified by
//   the combiner, collapsing the last dimension of the sparse tensor.
46//
47//   Output.dim = [l0, ... , ln-1, e1, ..., em]
48//   Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em]
49//
50//   For instance, if params is a 10x20 matrix and ids, weights are:
51//
52//   [0, 0]: id 1, weight 2.0
53//   [0, 1]: id 3, weight 0.5
54//   [1, 0]: id 0, weight 1.0
55//   [2, 3]: id 1, weight 3.0
56//
//   With combiner=MEAN, the output will be a (3, 20) tensor where:
58//
59//   output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
60//   output[1, :] = (params[0, :] * 1.0) / 1.0
61//   output[2, :] = (params[1, :] * 3.0) / 3.0
62//
//   When indices are out of bounds, the op will not succeed.
64
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
73
74namespace tflite {
75namespace ops {
76namespace builtin {
77
78namespace {
79
80TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
81  TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
82  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
83
84  TfLiteTensor* ids = GetInput(context, node, 0);
85  TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
86  TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
87
88  TfLiteTensor* indices = GetInput(context, node, 1);
89  TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
90  TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
91
92  TfLiteTensor* shape = GetInput(context, node, 2);
93  TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
94  TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
95
96  TfLiteTensor* weights = GetInput(context, node, 3);
97  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
98  TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
99
100  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
101                    SizeOfDimension(ids, 0));
102  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
103                    SizeOfDimension(weights, 0));
104
105  TfLiteTensor* value = GetInput(context, node, 4);
106  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
107
108  // Mark the output as a dynamic tensor.
109  TfLiteTensor* output = GetOutput(context, node, 0);
110  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
111  output->allocation_type = kTfLiteDynamic;
112
113  return kTfLiteOk;
114}
115
116void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
117                         float current_total_weight,
118                         float current_squares_weight, int embedding_size,
119                         float* output) {
120  if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
121    float multiplier = 1.0;
122    switch (combiner) {
123      case kTfLiteCombinerTypeMean:
124        multiplier = current_total_weight;
125        break;
126      case kTfLiteCombinerTypeSqrtn:
127        multiplier = std::sqrt(current_squares_weight);
128        break;
129      default:
130        break;
131    }
132    for (int k = 0; k < embedding_size; k++) {
133      output[k] /= multiplier;
134    }
135  }
136}
137
138TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
139  auto* params =
140      reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
141  TfLiteTensor* output = GetOutput(context, node, 0);
142  TfLiteTensor* ids = GetInput(context, node, 0);
143  TfLiteTensor* indices = GetInput(context, node, 1);
144  TfLiteTensor* dense_shape = GetInput(context, node, 2);
145  TfLiteTensor* weights = GetInput(context, node, 3);
146  TfLiteTensor* value = GetInput(context, node, 4);
147
148  const int lookup_rank = SizeOfDimension(indices, 1);
149  const int embedding_rank = NumDimensions(value);
150  const int num_lookups = SizeOfDimension(ids, 0);
151  const int num_rows = SizeOfDimension(value, 0);
152
153  // The last dimension gets replaced by the embedding.
154  const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);
155
156  // Make sure that the actual dense shape of the sparse tensor represented by
157  // (loopkup, indices, dense_shape) is consistent.
158  TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);
159
160  // Resize output tensor.
161  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
162  int k = 0;
163  int embedding_size = 1;
164  int lookup_size = 1;
165  for (int i = 0; i < lookup_rank - 1; i++, k++) {
166    const int dim = dense_shape->data.i32[i];
167    lookup_size *= dim;
168    output_shape->data[k] = dim;
169  }
170  for (int i = 1; i < embedding_rank; i++, k++) {
171    const int dim = SizeOfDimension(value, i);
172    embedding_size *= dim;
173    output_shape->data[k] = dim;
174  }
175  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
176  const int output_size = lookup_size * embedding_size;
177  TfLiteTensorRealloc(output_size * sizeof(float), output);
178
179  tensor_utils::ZeroVector(output->data.f, output_size);
180
181  // Keep track of the current bucket for aggregation/combination.
182  int current_output_offset = 0;
183  float current_total_weight = 0.0;
184  float current_squares_weight = 0.0;
185  int num_elements = 0;
186
187  for (int i = 0; i < num_lookups; i++) {
188    int idx = ids->data.i32[i];
189    if (idx >= num_rows || idx < 0) {
190      context->ReportError(context,
191                           "Embedding Lookup Sparse: index out of bounds.");
192      return kTfLiteError;
193    }
194
195    // Check where we need to aggregate.
196    const int example_indices_offset = i * lookup_rank;
197    int output_bucket = 0;
198    int stride = 1;
199    for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
200      output_bucket += indices->data.i32[example_indices_offset + k] * stride;
201      stride *= dense_shape->data.i32[k];
202    }
203    const int output_offset = output_bucket * embedding_size;
204
205    // If we are in a new aggregation bucket and the combiner is not the sum,
206    // go back and finalize the result of the previous bucket.
207    if (output_offset != current_output_offset) {
208      FinalizeAggregation(params->combiner, num_elements, current_total_weight,
209                          current_squares_weight, embedding_size,
210                          &output->data.f[current_output_offset]);
211
212      // Track next bucket.
213      num_elements = 0;
214      current_total_weight = 0.0;
215      current_squares_weight = 0.0;
216      current_output_offset = output_offset;
217    }
218
219    // Add element to aggregation.
220    ++num_elements;
221    const int example_embedding_offset = idx * embedding_size;
222    const float w = weights->data.f[i];
223    current_squares_weight += w * w;
224    current_total_weight += w;
225    for (int k = 0; k < embedding_size; k++) {
226      output->data.f[current_output_offset + k] +=
227          (value->data.f[example_embedding_offset + k] * w);
228    }
229  }
230
231  // Finalize last bucket.
232  FinalizeAggregation(params->combiner, num_elements, current_total_weight,
233                      current_squares_weight, embedding_size,
234                      &output->data.f[current_output_offset]);
235
236  return kTfLiteOk;
237}
238
239}  // namespace
240
241TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
242  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
243  return &r;
244}
245
246}  // namespace builtin
247}  // namespace ops
248}  // namespace tflite
249