reference_util.cc revision f89185b655d380074c3e1e932e90d80cd2b01241
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand) {
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 w = 0; w < operand.width(); ++w) {
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 h = 0; h < operand.height(); ++h) {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(w, h) = operand(h, w);
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs) {
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(m, n);
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<double>& lhs, const Array2D<double>& rhs) {
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(m, n);
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& input) {
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 rowno = 0; rowno < input.height(); ++rowno) {
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 colno = 0; colno < input.height(); ++colno) {
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(rowno, colno) = input(rowno, colno);
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding) {
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensions(
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs, rhs, kernel_stride, padding,
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ComputationBuilder::CreateDefaultConvDimensionNumbers());
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
967fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>>
977fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
987fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& depthwise_weights,
997fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& pointwise_weights,
1007fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    std::pair<int64, int64> kernel_stride,
1017fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    Padding padding) {
1027fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  const int64 depth_multiplier = depthwise_weights.planes();
1037fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
1047fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1057fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // Combine the two weights by reducing the depth_multiplier, so that we can
1067fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // apply a single convolution on the combined weights.
1077fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
1087fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                         depthwise_weights.height(), depthwise_weights.width());
1097fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
1107fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
1117fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      for (int64 kz = 0; kz < input.depth(); ++kz) {
1127fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
1137fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          float weight = 0.0;
1147fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
1157fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee            weight +=
1167fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                depthwise_weights(depth, kz, ky, kx) *
1177fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
1187fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          }
1197fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          weights(out, kz, ky, kx) = weight;
1207fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        }
1217fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      }
1227fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    }
1237fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  }
1247fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1257fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  return ConvArray4D(input, weights, kernel_stride, padding);
1267fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee}
1277fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              int64 window_len, int64 stride,
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              Padding padding) {
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kValid) {
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return window_util::StridedBound(unpadded_width, window_len, stride);
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, float init,
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           window_counts[2], window_counts[3]);
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D reduce window.
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = init;
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    val += operand(i0_base + i0_win, i1_base + i1_win,
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                   i2_base + i2_win, i3_base + i3_win);
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(i0, i1, i2, i3) = val;
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>>
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus(
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, const Array4D<float>& source, float init,
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           operand.n3(), operand.n4());
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Fill the output, with the initial value.
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[0], source.n1());
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[1], source.n2());
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[2], source.n3());
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[3], source.n4());
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D select and Scatter.
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          // Now we are inside a window and need to find the max and the argmax.
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        i2_base + i2_win, i3_base + i3_win);
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    if (tmp >= val) {
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      val = tmp;
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_0 = i0_base + i0_win;
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_1 = i1_base + i1_win;
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_2 = i2_base + i2_win;
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_3 = i3_base + i3_win;
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    }
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              source(i0, i1, i2, i3);
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions(
2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dimension_numbers) {
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                             {1, 1}, {1, 1}, dimension_numbers);
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated(
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dnums) {
2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksy = kernel_stride.first;
2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksx = kernel_stride.second;
2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dy = lhs_dilation.first;
2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dx = lhs_dilation.second;
2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dky = rhs_dilation.first;
2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dkx = rhs_dilation.second;
2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dky, 1);
2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dkx, 1);
2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dy, 1);
2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dx, 1);
2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Get all dimension sizes in lhs and rhs based on the given convolution
2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // dimension configuration.
2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ix = window_util::DilatedBound(
2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(1)], dx);
2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iy = window_util::DilatedBound(
2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(0)], dy);
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iz = lhs_dimensions[dnums.feature_dimension()];
3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 samples = lhs_dimensions[dnums.batch_dimension()];
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 kx = window_util::DilatedBound(
3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ky = window_util::DilatedBound(
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  {
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kiz, iz);
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kSame) {
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // We reject same padding with kernel striding, since it's somewhat
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // nonsensical. We can always follow up to implement this with the desired
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // semantics if anybody actually uses it.
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksy);
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksx);
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ox =
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oy =
3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istartx =
3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istarty =
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Create the output result array and reset the values to 0.
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> result_dimensions;
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.batch_dimension()] = samples;
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.feature_dimension()] = oz;
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(0)] = oy;
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(1)] = ox;
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result =
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 result_dimensions[2], result_dimensions[3]);
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(0.0);
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3387135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto is_int32 = [](int64 x) {
3397135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return x >= std::numeric_limits<int32>::min() &&
3407135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar           x <= std::numeric_limits<int32>::max();
3417135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3427135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
3437135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
3447135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // least on x86-64), so we avoid them where possible.
3457135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_idiv64 = [&](int64 a, int64 b) {
3467135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
3477135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
3487135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
3497135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a / b;
3507135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3517135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_imod64 = [&](int64 a, int64 b) {
3527135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
3537135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
3547135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
3557135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a % b;
3567135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3577135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the lhs operand at the given 4D index.
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               int64 width) {
3617135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return 0.0f;
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = feature;
3687135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
3697135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return lhs(index[0], index[1], index[2], index[3]);
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
373f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // Lambda to access the rhs operand at the given 4D index.  height_over_dky
374f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // should be equal to height / dky, and width_over_dkx should be equal to
375f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // width / dkx.  (This is an optimization to avoid doing divisions.)
376f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  const auto rhs_element = [&](
377f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar      int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
378f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar      int64 width, int64 height_over_dky, int64 width_over_dkx) {
379f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(height % dky, 0);
380f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(width % dkx, 0);
381f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(height / dky, height_over_dky);
382f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(width / dkx, width_over_dkx);
383f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
387f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
388f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return rhs(index[0], index[1], index[2], index[3]);
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the result data at the given 4D index.
3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto result_element = [&](int64 batch, int64 kernel_output_feature,
3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  int64 height, int64 width) -> float& {
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = kernel_output_feature;
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(0)] = height;
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(1)] = width;
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return (*result)(index[0], index[1], index[2], index[3]);
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 oyi = 0; oyi < oy; ++oyi) {
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 oxi = 0; oxi < ox; ++oxi) {
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 sample = 0; sample < samples; ++sample) {
4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 izi = 0; izi < iz; ++izi) {
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 ozi = 0; ozi < oz; ++ozi) {
408f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar            for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
409f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                 kyi += dky, kyi_over_dky++) {
410f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar              for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
411f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                   kxi += dkx, kxi_over_dkx++) {
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 iyi = istarty + ksy * oyi + kyi;
4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 ixi = istartx + ksx * oxi + kxi;
4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  ? 0.0
4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  : lhs_element(sample, izi, iyi, ixi);
417f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                float gain =
418f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                    rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float addend = input * gain;
4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                result_element(sample, ozi, oyi, oxi) += addend;
4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D(
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(i, j));
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D(
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < cols; ++i) {
4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < rows; ++j) {
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(j, i));
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& array, float init,
4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> result;
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 3);
4711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const std::set<int64> dim_set(dims.begin(), dims.end());
4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dim_set.size(), 3);
4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins         ++a1) {
4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
4771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           ++a2) {
4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins             ++a3) {
4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float accumulator = init;
4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               ++i0) {
4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 ++i1) {
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2 = 0;
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3 = 0;
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  accumulator = reduce_function(
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          result.push_back(accumulator);
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array3D<float>& array, float init,
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 1);
5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = dims[0] == 2 ? array.n2() : array.n3();
5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i0 = 0; i0 < array.n1(); ++i0) {
5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i1 = 0; i1 < array.n2(); ++i1) {
5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int i2 = 0; i2 < array.n3(); ++i2) {
5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 row = dims[0] == 0 ? i1 : i0;
5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 col = dims[0] == 2 ? i1 : i2;
5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (*result)(row, col) =
5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            reduce_function((*result)(row, col), array(i0, i1, i2));
5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float)>& map_function) {
5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j));
5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs,
5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, float)>& map_function) {
5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.height(), rhs.height());
5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.width());
5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = lhs.height();
5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = rhs.width();
5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, int64, int64)>& map_function) {
5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j), i, j);
5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand, const PaddingConfig& padding,
5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const float pad) {
5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in0 = operand.n1();
5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding0 = padding.dimensions(0).edge_padding_high();
5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding0 = padding.dimensions(0).edge_padding_low();
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding0 = padding.dimensions(0).interior_padding();
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out0 =
5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
5785aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in1 = operand.n2();
5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding1 = padding.dimensions(1).edge_padding_high();
5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding1 = padding.dimensions(1).edge_padding_low();
5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding1 = padding.dimensions(1).interior_padding();
5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out1 =
5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
5855aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(out0, out1);
5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(pad);
5885aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  int64 o0 = low_padding0;
5895aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  for (int64 i0 = 0; i0 < in0; ++i0) {
5905aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    int64 o1 = low_padding1;
5915aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    for (int64 i1 = 0; i1 < in1; ++i1) {
5925aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
5935aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan        (*result)(o0, o1) = operand(i0, i1);
5945aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      }
5955aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      o1 += interior_padding1 + 1;
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5975aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    o0 += interior_padding0 + 1;
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
603