reference_util.cc revision 0034029ac66bb60f272fa1aae05eca5dd9d210d1
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array>
19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand) {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 w = 0; w < operand.width(); ++w) {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 h = 0; h < operand.height(); ++h) {
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(w, h) = operand(h, w);
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs) {
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(m, n);
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<double>& lhs, const Array2D<double>& rhs) {
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(m, n);
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& input) {
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 rowno = 0; rowno < input.height(); ++rowno) {
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 colno = 0; colno < input.height(); ++colno) {
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(rowno, colno) = input(rowno, colno);
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding) {
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensions(
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs, rhs, kernel_stride, padding,
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ComputationBuilder::CreateDefaultConvDimensionNumbers());
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
977fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>>
987fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
997fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& depthwise_weights,
1007fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& pointwise_weights,
1017fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    std::pair<int64, int64> kernel_stride,
1027fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    Padding padding) {
1037fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  const int64 depth_multiplier = depthwise_weights.planes();
1047fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
1057fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1067fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // Combine the two weights by reducing the depth_multiplier, so that we can
1077fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // apply a single convolution on the combined weights.
1087fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
1097fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                         depthwise_weights.height(), depthwise_weights.width());
1107fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
1117fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
1127fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      for (int64 kz = 0; kz < input.depth(); ++kz) {
1137fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
1147fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          float weight = 0.0;
1157fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
1167fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee            weight +=
1177fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                depthwise_weights(depth, kz, ky, kx) *
1187fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
1197fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          }
1207fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          weights(out, kz, ky, kx) = weight;
1217fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        }
1227fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      }
1237fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    }
1247fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  }
1257fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1267fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  return ConvArray4D(input, weights, kernel_stride, padding);
1277fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee}
1287fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              int64 window_len, int64 stride,
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              Padding padding) {
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kValid) {
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return window_util::StridedBound(unpadded_width, window_len, stride);
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1380034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
1390034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric(
1400034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
1410034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const std::function<float(float, float)>& reduce_func,
1420034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1430034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1440034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
1450034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1460034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1470034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
1480034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
1490034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
1500034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    window_counts[i] =
1510034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
1520034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    pad_low[i] = padding_both[i].first;
1530034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
1540034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  auto result = MakeUnique<std::vector<float>>(window_counts[0]);
1550034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1560034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  // Do a full 1D reduce window.
1570034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
1580034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    int64 i0_base = i0 * stride[0] - pad_low[0];
1590034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1600034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    float val = init;
1610034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
1620034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
1630034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        val = reduce_func(val, operand[i0_base + i0_win]);
1640034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      }
1650034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    }
1660034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    (*result)[i0] = val;
1670034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
1680034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return result;
1690034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
1700034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1710034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
1720034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd(
1730034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
1740034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1750034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1760034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
1770034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride,
1780034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi                               padding);
1790034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
1800034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1816bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
1826bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const Array2D<float>& operand, float init,
1836bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1846bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1856bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> dim_lengths{operand.height(), operand.width()};
1866bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1876bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1886bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
1896bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
1906bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
1916bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    window_counts[i] =
1926bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
1936bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    pad_low[i] = padding_both[i].first;
1946bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
1956bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
1966bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1976bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  // Do a full 2D reduce window.
1986bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
1996bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2006bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i0_base = i0 * stride[0] - pad_low[0];
2016bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i1_base = i1 * stride[1] - pad_low[1];
2026bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2036bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      float val = init;
2046bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2056bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2066bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2076bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i0_base + i0_win < operand.n1() &&
2086bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i1_base + i1_win < operand.n2()) {
2096bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi            val += operand(i0_base + i0_win, i1_base + i1_win);
2106bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          }
2116bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        }
2126bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      }
2136bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      (*result)(i0, i1) = val;
2146bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    }
2156bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
2166bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  return result;
2176bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi}
2186bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2192d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>>
2202d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric(
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, float init,
2222d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const std::function<float(float, float)>& reduce_func,
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
227f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  return ReduceWindow4DGeneric(
228f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      operand, init, reduce_func, window, stride,
229f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      xla::MakePadding(dim_lengths, window, stride, padding));
230f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune}
231f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune
232f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>>
233f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric(
234f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const Array4D<float>& operand, float init,
235f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const std::function<float(float, float)>& reduce_func,
236f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& window,
237f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& stride,
238f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
239f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
240f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune                                 operand.n4()};
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
245f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
247f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune        window_util::StridedBound(padded_width, window[i], stride[i]);
248f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    pad_low[i] = padding[i].first;
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           window_counts[2], window_counts[3]);
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D reduce window.
2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = init;
2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
2732d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                    val = reduce_func(
2742d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                        val, operand(i0_base + i0_win, i1_base + i1_win,
2752d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                                     i2_base + i2_win, i3_base + i3_win));
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(i0, i1, i2, i3) = val;
2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2892d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
2902d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const Array4D<float>& operand, float init,
2912d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2922d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2932d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
2942d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
2952d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                               padding);
2962d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi}
2972d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi
2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& input, const Array4D<float>& mean,
3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& var, const Array4D<float>& scale,
3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& offset, float epsilon) {
3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto normalized =
3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(input, mean, [](float a, float b) { return a - b; });
3041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized = *MapArray4D(normalized, var, [&](float a, float b) {
3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    return a / std::sqrt(b + epsilon);
3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  });
3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized =
3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>>
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus(
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, const Array4D<float>& source, float init,
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           operand.n3(), operand.n4());
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Fill the output, with the initial value.
3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[0], source.n1());
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[1], source.n2());
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[2], source.n3());
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[3], source.n4());
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D select and Scatter.
3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          // Now we are inside a window and need to find the max and the argmax.
3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        i2_base + i2_win, i3_base + i3_win);
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    if (tmp >= val) {
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      val = tmp;
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_0 = i0_base + i0_win;
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_1 = i1_base + i1_win;
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_2 = i2_base + i2_win;
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_3 = i3_base + i3_win;
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    }
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              source(i0, i1, i2, i3);
3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions(
3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dimension_numbers) {
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
392f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             {1, 1}, {1, 1},
393f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             std::move(dimension_numbers));
3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated(
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dnums) {
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksy = kernel_stride.first;
4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksx = kernel_stride.second;
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dy = lhs_dilation.first;
4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dx = lhs_dilation.second;
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dky = rhs_dilation.first;
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dkx = rhs_dilation.second;
4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dky, 1);
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dkx, 1);
4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dy, 1);
4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dx, 1);
4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Get all dimension sizes in lhs and rhs based on the given convolution
4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // dimension configuration.
4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ix = window_util::DilatedBound(
4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(1)], dx);
4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iy = window_util::DilatedBound(
4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(0)], dy);
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iz = lhs_dimensions[dnums.feature_dimension()];
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 samples = lhs_dimensions[dnums.batch_dimension()];
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 kx = window_util::DilatedBound(
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ky = window_util::DilatedBound(
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  {
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kiz, iz);
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kSame) {
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // We reject same padding with kernel striding, since it's somewhat
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // nonsensical. We can always follow up to implement this with the desired
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // semantics if anybody actually uses it.
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksy);
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksx);
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ox =
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oy =
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istartx =
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istarty =
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Create the output result array and reset the values to 0.
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> result_dimensions;
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.batch_dimension()] = samples;
4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.feature_dimension()] = oz;
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(0)] = oy;
4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(1)] = ox;
4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result =
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 result_dimensions[2], result_dimensions[3]);
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(0.0);
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4617135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto is_int32 = [](int64 x) {
4627135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return x >= std::numeric_limits<int32>::min() &&
4637135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar           x <= std::numeric_limits<int32>::max();
4647135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
4657135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
4667135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
4677135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // least on x86-64), so we avoid them where possible.
4687135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_idiv64 = [&](int64 a, int64 b) {
4697135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
4707135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
4717135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
4727135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a / b;
4737135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
4747135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_imod64 = [&](int64 a, int64 b) {
4757135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
4767135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
4777135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
4787135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a % b;
4797135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
4807135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the lhs operand at the given 4D index.
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               int64 width) {
4847135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return 0.0f;
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = feature;
4917135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
4927135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return lhs(index[0], index[1], index[2], index[3]);
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
496f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // Lambda to access the rhs operand at the given 4D index.  height_over_dky
497f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // should be equal to height / dky, and width_over_dkx should be equal to
498f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // width / dkx.  (This is an optimization to avoid doing divisions.)
4990034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  const auto rhs_element =
5000034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
5010034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi          int64 width, int64 height_over_dky, int64 width_over_dkx) {
5020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        DCHECK_EQ(height % dky, 0);
5030034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        DCHECK_EQ(width % dkx, 0);
5040034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        DCHECK_EQ(height / dky, height_over_dky);
5050034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        DCHECK_EQ(width / dkx, width_over_dkx);
5060034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
5070034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        std::array<int64, 4> index;
5080034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
5090034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
5100034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
5110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
5120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        return rhs(index[0], index[1], index[2], index[3]);
5130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      };
5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the result data at the given 4D index.
5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto result_element = [&](int64 batch, int64 kernel_output_feature,
5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  int64 height, int64 width) -> float& {
5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = kernel_output_feature;
5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(0)] = height;
5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(1)] = width;
5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return (*result)(index[0], index[1], index[2], index[3]);
5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 oyi = 0; oyi < oy; ++oyi) {
5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 oxi = 0; oxi < ox; ++oxi) {
5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 sample = 0; sample < samples; ++sample) {
5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 izi = 0; izi < iz; ++izi) {
5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 ozi = 0; ozi < oz; ++ozi) {
531f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar            for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
532f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                 kyi += dky, kyi_over_dky++) {
533f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar              for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
534f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                   kxi += dkx, kxi_over_dkx++) {
5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 iyi = istarty + ksy * oyi + kyi;
5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 ixi = istartx + ksx * oxi + kxi;
5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  ? 0.0
5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  : lhs_element(sample, izi, iyi, ixi);
540f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                float gain =
541f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                    rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float addend = input * gain;
5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                result_element(sample, ozi, oyi, oxi) += addend;
5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5511c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
5521c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower      iz == 0) {
5531c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower    LOG(INFO) << "Output will be trivially empty because one of these "
5541c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower                 "dimensions is 0: samples: "
5551c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower              << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
5561c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower              << " oy: " << oy << " oz: " << oz << " iz: " << iz;
5571c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower    return result;
5581c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  }
5591c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  bool trivial = true;
5601c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
5611c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower                                  float value) {
5621c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower    if (value != 0.0) {
5631c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower      trivial = false;
5641c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower    }
5651c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  };
5661c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  lhs.Each(check_trivial);
567beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins  if (trivial) {
568beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins    LOG(FATAL) << "LHS is all 0.0.";
569beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins  }
5701c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  trivial = true;
5711c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower  rhs.Each(check_trivial);
572beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins  if (trivial) {
573beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins    LOG(FATAL) << "RHS is all 0.0.";
574beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins  }
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D(
5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(i, j));
5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D(
5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < cols; ++i) {
6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < rows; ++j) {
6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(j, i));
6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& array, float init,
6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
6161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> result;
6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 3);
6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const std::set<int64> dim_set(dims.begin(), dims.end());
6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dim_set.size(), 3);
6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins         ++a1) {
6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           ++a2) {
6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
6261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins             ++a3) {
6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float accumulator = init;
6281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               ++i0) {
6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 ++i1) {
6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2 = 0;
6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3 = 0;
6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  accumulator = reduce_function(
6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          result.push_back(accumulator);
6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
6511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const std::vector<float>& array, const std::vector<int64>& bounds,
6521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    int64 broadcast_from_dim) {
6531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto result =
6541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
6551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < result->n1(); ++i) {
6561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    for (int64 j = 0; j < result->n2(); ++j) {
6571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      for (int64 k = 0; k < result->n3(); ++k) {
6581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        for (int64 l = 0; l < result->n4(); ++l) {
6591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          switch (broadcast_from_dim) {
6601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 0:
6611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[i];
6621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 1:
6641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[j];
6651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 2:
6671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[k];
6681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 3:
6701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[l];
6711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            default:
6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          }
6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        }
6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      }
6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return result;
6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array3D<float>& array, float init,
6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 1);
6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = dims[0] == 2 ? array.n2() : array.n3();
6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
6901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
6911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i0 = 0; i0 < array.n1(); ++i0) {
6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i1 = 0; i1 < array.n2(); ++i1) {
6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int i2 = 0; i2 < array.n3(); ++i2) {
6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 row = dims[0] == 0 ? i1 : i0;
6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 col = dims[0] == 2 ? i1 : i2;
6961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (*result)(row, col) =
6971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            reduce_function((*result)(row, col), array(i0, i1, i2));
6981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
6991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float)>& map_function) {
7071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j));
7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs,
7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, float)>& map_function) {
7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.height(), rhs.height());
7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.width());
7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = lhs.height();
7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = rhs.width();
7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, int64, int64)>& map_function) {
7371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j), i, j);
7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand, const PaddingConfig& padding,
7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const float pad) {
7511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in0 = operand.n1();
7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding0 = padding.dimensions(0).edge_padding_high();
7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding0 = padding.dimensions(0).edge_padding_low();
7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding0 = padding.dimensions(0).interior_padding();
7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out0 =
7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
7575aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in1 = operand.n2();
7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding1 = padding.dimensions(1).edge_padding_high();
7601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding1 = padding.dimensions(1).edge_padding_low();
7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding1 = padding.dimensions(1).interior_padding();
7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out1 =
7631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
7645aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
7651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(out0, out1);
7661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(pad);
7675aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  int64 o0 = low_padding0;
7685aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  for (int64 i0 = 0; i0 < in0; ++i0) {
7695aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    int64 o1 = low_padding1;
7705aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    for (int64 i1 = 0; i1 < in1; ++i1) {
7715aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
7725aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan        (*result)(o0, o1) = operand(i0, i1);
7735aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      }
7745aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      o1 += interior_padding1 + 1;
7751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7765aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    o0 += interior_padding0 + 1;
7771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
781c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune/* static */ Array4D<float> ReferenceUtil::PadArray4D(
782c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    const Array4D<float>& operand, const PaddingConfig& padding,
783c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    const float pad) {
784c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  CHECK_EQ(padding.dimensions_size(), 4);
785c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
786c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
787c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune                                           operand.n3(), operand.n4()};
788c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> pad_low(4);
789c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> pad_high(4);
7908ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower  std::vector<int64> pad_interior(4);
791c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> output_bounds(4);
792c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  for (int64 i = 0; i < 4; ++i) {
793c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    pad_low[i] = padding.dimensions(i).edge_padding_low();
794c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    pad_high[i] = padding.dimensions(i).edge_padding_high();
7958ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented";
7968ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    pad_interior[i] = padding.dimensions(i).interior_padding();
797c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
7988ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
7998ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                       (input_bounds[i] - 1) * pad_interior[i];
800c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  }
801c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
802c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
803c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune                        output_bounds[3]);
804c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
805c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    for (int i = 0; i < 4; ++i) {
806c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      bool in_low_padding = indices[i] < pad_low[i];
807c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
808c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      if (in_low_padding || in_high_padding) {
809c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune        *value = pad;
810c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune        return;
811c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      }
8128ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower      if (pad_interior[i] &&
8138ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower          (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
8148ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower        *value = pad;
8158ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower        return;
8168ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower      }
817c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    }
8188ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
8198ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
8208ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
8218ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
822c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  });
823c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  return result;
824c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune}
825c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
8261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
827