157600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
357600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargLicensed under the Apache License, Version 2.0 (the "License");
457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargyou may not use this file except in compliance with the License.
557600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargYou may obtain a copy of the License at
657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg    http://www.apache.org/licenses/LICENSE-2.0
857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
957600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargUnless required by applicable law or agreed to in writing, software
1057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargdistributed under the License is distributed on an "AS IS" BASIS,
1157600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1257600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargSee the License for the specific language governing permissions and
1357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garglimitations under the License.
1457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg==============================================================================*/
1557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include <string.h>
1657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include <vector>
1757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/builtin_op_data.h"
1857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/context.h"
1957600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
2057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
2157600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
2257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/kernels/kernel_util.h"
2357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg#include "tensorflow/contrib/lite/kernels/op_macros.h"
2457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
2557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargnamespace tflite {
2657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargnamespace ops {
2757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargnamespace builtin {
2857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargnamespace pad {
2957600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
// Pad is provided in two flavours; this enum is the template argument that
// picks which one Eval() dispatches to.
enum KernelType {
  // Selects the reference_ops implementation.
  kReference = 0,
  // Selects the optimized_ops implementation.
  kGenericOptimized = 1,
};
3557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
3657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargstruct PadContext {
3757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  PadContext(TfLiteContext* context, TfLiteNode* node) {
3857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg    input = GetInput(context, node, 0);
3968ab9a99cf07ed5216d310417873572043803a1eNupur Garg    paddings = GetInput(context, node, 1);
4057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg    output = GetOutput(context, node, 0);
4168ab9a99cf07ed5216d310417873572043803a1eNupur Garg    dims = NumDimensions(input);
4257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  }
4357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  TfLiteTensor* input;
4468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  TfLiteTensor* paddings;
4557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  TfLiteTensor* output;
4668ab9a99cf07ed5216d310417873572043803a1eNupur Garg  int dims;
4757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg};
4857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
4968ab9a99cf07ed5216d310417873572043803a1eNupur Garg// Resizes output array based on the input size and padding size. This function
5068ab9a99cf07ed5216d310417873572043803a1eNupur Garg// is callable from both Prepare() and Eval() as long as the caller ensures the
5168ab9a99cf07ed5216d310417873572043803a1eNupur Garg// paddings data is present.
5268ab9a99cf07ed5216d310417873572043803a1eNupur GargTfLiteStatus ResizeOutputTensor(TfLiteContext* context,
5368ab9a99cf07ed5216d310417873572043803a1eNupur Garg                                PadContext* op_context) {
5468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Ensures the paddings array is dims x 2.
5568ab9a99cf07ed5216d310417873572043803a1eNupur Garg  TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
5668ab9a99cf07ed5216d310417873572043803a1eNupur Garg                    op_context->dims);
5768ab9a99cf07ed5216d310417873572043803a1eNupur Garg  TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
5868ab9a99cf07ed5216d310417873572043803a1eNupur Garg
5968ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Determines the size of the output tensor.
6003bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  TfLiteIntArray* input_size = op_context->input->dims;
6103bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
6268ab9a99cf07ed5216d310417873572043803a1eNupur Garg  const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
6368ab9a99cf07ed5216d310417873572043803a1eNupur Garg
6468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  for (int idx = 0; idx < op_context->dims; ++idx) {
6568ab9a99cf07ed5216d310417873572043803a1eNupur Garg    int before_padding = *paddings_data++;
6668ab9a99cf07ed5216d310417873572043803a1eNupur Garg    int after_padding = *paddings_data++;
6768ab9a99cf07ed5216d310417873572043803a1eNupur Garg
6868ab9a99cf07ed5216d310417873572043803a1eNupur Garg    TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
6968ab9a99cf07ed5216d310417873572043803a1eNupur Garg                       "Pad value has to be greater than equal to 0.");
7068ab9a99cf07ed5216d310417873572043803a1eNupur Garg
7168ab9a99cf07ed5216d310417873572043803a1eNupur Garg    output_size->data[idx] =
7268ab9a99cf07ed5216d310417873572043803a1eNupur Garg        (input_size->data[idx] + before_padding + after_padding);
7368ab9a99cf07ed5216d310417873572043803a1eNupur Garg  }
7468ab9a99cf07ed5216d310417873572043803a1eNupur Garg
7568ab9a99cf07ed5216d310417873572043803a1eNupur Garg  return context->ResizeTensor(context, op_context->output, output_size);
7668ab9a99cf07ed5216d310417873572043803a1eNupur Garg}
7768ab9a99cf07ed5216d310417873572043803a1eNupur Garg
7857600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargTfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7968ab9a99cf07ed5216d310417873572043803a1eNupur Garg  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
8057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
8157600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
8257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  PadContext op_context(context, node);
837255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
8457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
8503bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
8603bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
8703bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg
8868ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Exit early if paddings is a non-const tensor. Set output tensor to
8968ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // dynamic so output size can be determined in Eval.
9003bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  if (!IsConstantTensor(op_context.paddings)) {
9103bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg    SetTensorToDynamic(op_context.output);
9268ab9a99cf07ed5216d310417873572043803a1eNupur Garg    return kTfLiteOk;
9357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  }
9468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  return ResizeOutputTensor(context, &op_context);
9557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}
9657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
9757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Gargtemplate <KernelType kernel_type>
9857600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargTfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
9957600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  PadContext op_context(context, node);
10057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
10168ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Resize the output tensor if the output tensor is dynamic.
10203bb1c4a6014fdf3f10f301f093ec02d84f717c7Nupur Garg  if (IsDynamicTensor(op_context.output)) {
10368ab9a99cf07ed5216d310417873572043803a1eNupur Garg    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
10468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  }
10568ab9a99cf07ed5216d310417873572043803a1eNupur Garg
10668ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // TODO(nupurgarg): Change kernel implementation to take in int* instead of
10768ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // vector<int> to remove malloc from Eval().
10868ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Create before and after padding arrays that are accepted by the kernel.
10968ab9a99cf07ed5216d310417873572043803a1eNupur Garg  std::vector<int> before_padding;
11068ab9a99cf07ed5216d310417873572043803a1eNupur Garg  std::vector<int> after_padding;
11168ab9a99cf07ed5216d310417873572043803a1eNupur Garg  const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
11268ab9a99cf07ed5216d310417873572043803a1eNupur Garg
11368ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // TODO(nupurgarg): Change kernel implementation to use padding arrays in
11468ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // forward order (depth, width, height, batch).
11568ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // Build paddings in order of int[] = {batch, height, width, depth} to match
11668ab9a99cf07ed5216d310417873572043803a1eNupur Garg  // kernel implementation of Pad in referenced_ops.h and optimized_ops.h.
11768ab9a99cf07ed5216d310417873572043803a1eNupur Garg  for (int idx = op_context.dims - 1; idx >= 0; --idx) {
11868ab9a99cf07ed5216d310417873572043803a1eNupur Garg    before_padding.push_back(paddings_data[idx * 2]);
11968ab9a99cf07ed5216d310417873572043803a1eNupur Garg    after_padding.push_back(paddings_data[idx * 2 + 1]);
12068ab9a99cf07ed5216d310417873572043803a1eNupur Garg  }
1217255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg
1227255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg#define TF_LITE_PAD(type, scalar)                                           \
1237255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg  type::Pad(GetTensorData<scalar>(op_context.input),                        \
12457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg            GetTensorDims(op_context.input), before_padding, after_padding, \
1257255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg            GetTensorData<scalar>(op_context.output),                       \
12657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg            GetTensorDims(op_context.output))
12757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
1287255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg  switch (op_context.input->type) {
1297255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg    case kTfLiteFloat32:
1307255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      if (kernel_type == kReference) {
1317255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(reference_ops, float);
1327255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      } else if (kernel_type == kGenericOptimized) {
1337255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(optimized_ops, float);
1347255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      }
1357255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      break;
1367255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg    case kTfLiteUInt8:
1377255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      if (kernel_type == kReference) {
1387255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(reference_ops, uint8_t);
1397255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      } else if (kernel_type == kGenericOptimized) {
1407255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(optimized_ops, uint8_t);
1417255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      }
1427255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      break;
1437255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg    case kTfLiteInt32:
1447255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      if (kernel_type == kReference) {
1457255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(reference_ops, int32_t);
1467255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      } else if (kernel_type == kGenericOptimized) {
1477255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(optimized_ops, int32_t);
1487255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      }
1497255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      break;
1507255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg    case kTfLiteInt64:
1517255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      if (kernel_type == kReference) {
1527255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(reference_ops, int64_t);
1537255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      } else if (kernel_type == kGenericOptimized) {
1547255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg        TF_LITE_PAD(optimized_ops, int64_t);
1557255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      }
1567255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      break;
1577255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg    default:
1587255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      context->ReportError(context, "Type is currently not supported by Pad.");
1597255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg      return kTfLiteError;
16057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  }
1617255b9819f72b681aa66876ef0bd5ddfe67099f4Nupur Garg#undef TF_LITE_PAD
16257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  return kTfLiteOk;
16357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}
16457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
16557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}  // namespace pad
16657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
16757600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargTfLiteRegistration* Register_PAD_REF() {
16857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
16957600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg                                 pad::Eval<pad::kReference>};
17057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  return &r;
17157600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}
17257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
17357600a8d7739f6fbea445c6efa1f29f12f769748Nupur GargTfLiteRegistration* Register_PAD_GENERIC_OPT() {
17457600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
17557600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg                                 pad::Eval<pad::kGenericOptimized>};
17657600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg  return &r;
17757600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}
17857600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
1794463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlowerTfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
18057600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg
18157600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}  // namespace builtin
18257600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}  // namespace ops
18357600a8d7739f6fbea445c6efa1f29f12f769748Nupur Garg}  // namespace tflite
184