reference_util.cc revision 6bbbd7e9d2016dfd201797d1f1354ccc48bd9e13
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand) {
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 w = 0; w < operand.width(); ++w) {
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 h = 0; h < operand.height(); ++h) {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(w, h) = operand(h, w);
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs) {
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(m, n);
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<double>& lhs, const Array2D<double>& rhs) {
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(m, n);
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& input) {
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 rowno = 0; rowno < input.height(); ++rowno) {
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 colno = 0; colno < input.height(); ++colno) {
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(rowno, colno) = input(rowno, colno);
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding) {
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensions(
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs, rhs, kernel_stride, padding,
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ComputationBuilder::CreateDefaultConvDimensionNumbers());
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
967fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>>
977fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
987fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& depthwise_weights,
997fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& pointwise_weights,
1007fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    std::pair<int64, int64> kernel_stride,
1017fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    Padding padding) {
1027fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  const int64 depth_multiplier = depthwise_weights.planes();
1037fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
1047fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1057fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // Combine the two weights by reducing the depth_multiplier, so that we can
1067fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // apply a single convolution on the combined weights.
1077fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
1087fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                         depthwise_weights.height(), depthwise_weights.width());
1097fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
1107fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
1117fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      for (int64 kz = 0; kz < input.depth(); ++kz) {
1127fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
1137fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          float weight = 0.0;
1147fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
1157fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee            weight +=
1167fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                depthwise_weights(depth, kz, ky, kx) *
1177fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
1187fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          }
1197fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          weights(out, kz, ky, kx) = weight;
1207fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        }
1217fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      }
1227fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    }
1237fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  }
1247fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1257fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  return ConvArray4D(input, weights, kernel_stride, padding);
1267fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee}
1277fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              int64 window_len, int64 stride,
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              Padding padding) {
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kValid) {
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return window_util::StridedBound(unpadded_width, window_len, stride);
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1376bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
1386bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const Array2D<float>& operand, float init,
1396bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1406bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1416bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> dim_lengths{operand.height(), operand.width()};
1426bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1436bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1446bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
1456bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
1466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
1476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    window_counts[i] =
1486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
1496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    pad_low[i] = padding_both[i].first;
1506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
1516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
1526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  // Do a full 2D reduce window.
1546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
1556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
1566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i0_base = i0 * stride[0] - pad_low[0];
1576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i1_base = i1 * stride[1] - pad_low[1];
1586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      float val = init;
1606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
1616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
1626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
1636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i0_base + i0_win < operand.n1() &&
1646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i1_base + i1_win < operand.n2()) {
1656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi            val += operand(i0_base + i0_win, i1_base + i1_win);
1666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          }
1676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        }
1686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      }
1696bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      (*result)(i0, i1) = val;
1706bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    }
1716bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
1726bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  return result;
1736bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi}
1746bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, float init,
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           window_counts[2], window_counts[3]);
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D reduce window.
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = init;
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    val += operand(i0_base + i0_win, i1_base + i1_win,
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                   i2_base + i2_win, i3_base + i3_win);
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(i0, i1, i2, i3) = val;
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>>
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus(
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, const Array4D<float>& source, float init,
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           operand.n3(), operand.n4());
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Fill the output, with the initial value.
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[0], source.n1());
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[1], source.n2());
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[2], source.n3());
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[3], source.n4());
2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D select and Scatter.
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          // Now we are inside a window and need to find the max and the argmax.
2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        i2_base + i2_win, i3_base + i3_win);
2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    if (tmp >= val) {
2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      val = tmp;
2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_0 = i0_base + i0_win;
2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_1 = i1_base + i1_win;
2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_2 = i2_base + i2_win;
2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_3 = i3_base + i3_win;
2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    }
2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              source(i0, i1, i2, i3);
2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions(
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dimension_numbers) {
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                             {1, 1}, {1, 1}, dimension_numbers);
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated(
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dnums) {
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksy = kernel_stride.first;
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ksx = kernel_stride.second;
3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dy = lhs_dilation.first;
3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dx = lhs_dilation.second;
3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dky = rhs_dilation.first;
3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 dkx = rhs_dilation.second;
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dky, 1);
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dkx, 1);
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dy, 1);
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_GE(dx, 1);
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Get all dimension sizes in lhs and rhs based on the given convolution
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // dimension configuration.
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ix = window_util::DilatedBound(
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(1)], dx);
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iy = window_util::DilatedBound(
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs_dimensions[dnums.spatial_dimensions(0)], dy);
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 iz = lhs_dimensions[dnums.feature_dimension()];
3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 samples = lhs_dimensions[dnums.batch_dimension()];
3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 kx = window_util::DilatedBound(
3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ky = window_util::DilatedBound(
3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  {
3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kiz, iz);
3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kSame) {
3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // We reject same padding with kernel striding, since it's somewhat
3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // nonsensical. We can always follow up to implement this with the desired
3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // semantics if anybody actually uses it.
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksy);
3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(1, ksx);
3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 ox =
3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 oy =
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istartx =
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const int64 istarty =
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Create the output result array and reset the values to 0.
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::array<int64, 4> result_dimensions;
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.batch_dimension()] = samples;
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.feature_dimension()] = oz;
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(0)] = oy;
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result_dimensions[dnums.spatial_dimensions(1)] = ox;
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result =
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 result_dimensions[2], result_dimensions[3]);
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(0.0);
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3767135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto is_int32 = [](int64 x) {
3777135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return x >= std::numeric_limits<int32>::min() &&
3787135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar           x <= std::numeric_limits<int32>::max();
3797135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3807135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
3817135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
3827135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  // least on x86-64), so we avoid them where possible.
3837135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_idiv64 = [&](int64 a, int64 b) {
3847135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
3857135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
3867135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
3877135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a / b;
3887135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3897135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  const auto fast_imod64 = [&](int64 a, int64 b) {
3907135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (is_int32(a) && is_int32(b)) {
3917135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
3927135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    }
3937135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    return a % b;
3947135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar  };
3957135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the lhs operand at the given 4D index.
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               int64 width) {
3997135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return 0.0f;
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = feature;
4067135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
4077135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return lhs(index[0], index[1], index[2], index[3]);
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // Lambda to access the rhs operand at the given 4D index.  height_over_dky
412f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // should be equal to height / dky, and width_over_dkx should be equal to
413f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  // width / dkx.  (This is an optimization to avoid doing divisions.)
414f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar  const auto rhs_element = [&](
415f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar      int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
416f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar      int64 width, int64 height_over_dky, int64 width_over_dkx) {
417f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(height % dky, 0);
418f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(width % dkx, 0);
419f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(height / dky, height_over_dky);
420f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    DCHECK_EQ(width / dkx, width_over_dkx);
421f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
425f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
426f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar    index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return rhs(index[0], index[1], index[2], index[3]);
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Lambda to access the result data at the given 4D index.
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const auto result_element = [&](int64 batch, int64 kernel_output_feature,
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  int64 height, int64 width) -> float& {
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::array<int64, 4> index;
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.batch_dimension()] = batch;
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.feature_dimension()] = kernel_output_feature;
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(0)] = height;
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    index[dnums.spatial_dimensions(1)] = width;
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return (*result)(index[0], index[1], index[2], index[3]);
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  };
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 oyi = 0; oyi < oy; ++oyi) {
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 oxi = 0; oxi < ox; ++oxi) {
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 sample = 0; sample < samples; ++sample) {
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 izi = 0; izi < iz; ++izi) {
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 ozi = 0; ozi < oz; ++ozi) {
446f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar            for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
447f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                 kyi += dky, kyi_over_dky++) {
448f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar              for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
449f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                   kxi += dkx, kxi_over_dkx++) {
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 iyi = istarty + ksy * oyi + kyi;
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                int64 ixi = istartx + ksx * oxi + kxi;
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  ? 0.0
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  : lhs_element(sample, izi, iyi, ixi);
455f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                float gain =
456f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar                    rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                float addend = input * gain;
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                result_element(sample, ozi, oyi, oxi) += addend;
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D(
4711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
4771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(i, j));
4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D(
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < cols; ++i) {
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < rows; ++j) {
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(j, i));
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& array, float init,
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> result;
5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 3);
5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const std::set<int64> dim_set(dims.begin(), dims.end());
5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dim_set.size(), 3);
5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins         ++a1) {
5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           ++a2) {
5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins             ++a3) {
5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float accumulator = init;
5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               ++i0) {
5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 ++i1) {
5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2 = 0;
5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3 = 0;
5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  accumulator = reduce_function(
5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          result.push_back(accumulator);
5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array3D<float>& array, float init,
5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::function<float(float, float)> reduce_function) {
5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 1);
5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = dims[0] == 2 ? array.n2() : array.n3();
5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i0 = 0; i0 < array.n1(); ++i0) {
5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i1 = 0; i1 < array.n2(); ++i1) {
5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int i2 = 0; i2 < array.n3(); ++i2) {
5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 row = dims[0] == 0 ? i1 : i0;
5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 col = dims[0] == 2 ? i1 : i2;
5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (*result)(row, col) =
5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            reduce_function((*result)(row, col), array(i0, i1, i2));
5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float)>& map_function) {
5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j));
5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs,
5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, float)>& map_function) {
5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.height(), rhs.height());
5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.width());
5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = lhs.height();
5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = rhs.width();
5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, int64, int64)>& map_function) {
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j), i, j);
6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand, const PaddingConfig& padding,
6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const float pad) {
6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in0 = operand.n1();
6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding0 = padding.dimensions(0).edge_padding_high();
6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding0 = padding.dimensions(0).edge_padding_low();
6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding0 = padding.dimensions(0).interior_padding();
6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out0 =
6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
6165aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in1 = operand.n2();
6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding1 = padding.dimensions(1).edge_padding_high();
6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding1 = padding.dimensions(1).edge_padding_low();
6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding1 = padding.dimensions(1).interior_padding();
6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out1 =
6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
6235aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(out0, out1);
6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(pad);
6265aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  int64 o0 = low_padding0;
6275aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  for (int64 i0 = 0; i0 < in0; ++i0) {
6285aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    int64 o1 = low_padding1;
6295aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    for (int64 i1 = 0; i1 < in1; ++i1) {
6305aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
6315aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan        (*result)(o0, o1) = operand(i0, i1);
6325aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      }
6335aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      o1 += interior_padding1 + 1;
6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6355aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    o0 += interior_padding0 + 1;
6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
641