reference_util.cc revision b1c10555afe9ad4ebebbd83eb31dbf8006d7980b
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array>
19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
23253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
24253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_instruction.h"
25253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/shape_inference.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand) {
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 w = 0; w < operand.width(); ++w) {
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 h = 0; h < operand.height(); ++h) {
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(w, h) = operand(h, w);
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs) {
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(m, n);
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<double>& lhs, const Array2D<double>& rhs) {
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(m, n);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& input) {
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 rowno = 0; rowno < input.height(); ++rowno) {
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 colno = 0; colno < input.height(); ++colno) {
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(rowno, colno) = input(rowno, colno);
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
929b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
939b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
949b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    Padding padding) {
959b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  return ConvArray3DGeneralDimensionsDilated(
969b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      lhs, rhs, kernel_stride, padding, 1, 1,
979b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      ComputationBuilder::CreateDefaultConvDimensionNumbers(1));
989b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower}
999b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1009b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*static*/ std::unique_ptr<Array3D<float>>
1019b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlowerReferenceUtil::ConvArray3DGeneralDimensionsDilated(
1029b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
1039b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    Padding padding, int64 lhs_dilation, int64 rhs_dilation,
1049b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const ConvolutionDimensionNumbers& dnums) {
1059b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  CHECK_EQ(dnums.spatial_dimensions_size(), 1);
1069b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
1079b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // array by adding a fourth dummy dimension of size 1 without stride, padding
1089b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // and dilation.
1099b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
1109b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  a4dlhs.Each(
1119b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1129b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1139b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
1149b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1159b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
1169b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  a4drhs.Each(
1179b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1189b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1199b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
1209b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1219b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // Add a second dummy spatial dimensions.
1229b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  ConvolutionDimensionNumbers dnums2d = dnums;
1239b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  dnums2d.add_spatial_dimensions(3);
1249b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  dnums2d.add_kernel_spatial_dimensions(3);
1259b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
1269b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
1279b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      {rhs_dilation, 1}, dnums2d);
1289b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1299b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(),
1309b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                                           convr4->height());
1319b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  convr4->Each(
1329b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1339b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1349b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
1359b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1369b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  return convr3;
1379b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower}
1389b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding) {
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensions(
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs, rhs, kernel_stride, padding,
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ComputationBuilder::CreateDefaultConvDimensionNumbers());
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1477fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>>
1487fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
1497fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& depthwise_weights,
1507fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& pointwise_weights,
1517fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    std::pair<int64, int64> kernel_stride,
1527fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    Padding padding) {
1537fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  const int64 depth_multiplier = depthwise_weights.planes();
1547fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
1557fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1567fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // Combine the two weights by reducing the depth_multiplier, so that we can
1577fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // apply a single convolution on the combined weights.
1587fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
1597fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                         depthwise_weights.height(), depthwise_weights.width());
1607fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
1617fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
1627fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      for (int64 kz = 0; kz < input.depth(); ++kz) {
1637fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
1647fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          float weight = 0.0;
1657fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
1667fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee            weight +=
1677fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                depthwise_weights(depth, kz, ky, kx) *
1687fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
1697fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          }
1707fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          weights(out, kz, ky, kx) = weight;
1717fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        }
1727fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      }
1737fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    }
1747fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  }
1757fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1767fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  return ConvArray4D(input, weights, kernel_stride, padding);
1777fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee}
1787fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              int64 window_len, int64 stride,
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              Padding padding) {
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kValid) {
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return window_util::StridedBound(unpadded_width, window_len, stride);
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1880034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
1890034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric(
1900034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
1910034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const std::function<float(float, float)>& reduce_func,
1920034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1930034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1940034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
1950034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
1960034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
1970034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
1980034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
1990034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
2000034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    window_counts[i] =
2010034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    pad_low[i] = padding_both[i].first;
2030034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
2040034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  auto result = MakeUnique<std::vector<float>>(window_counts[0]);
2050034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2060034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  // Do a full 1D reduce window.
2070034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2080034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    int64 i0_base = i0 * stride[0] - pad_low[0];
2090034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2100034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    float val = init;
2110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
2130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        val = reduce_func(val, operand[i0_base + i0_win]);
2140034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      }
2150034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    }
2160034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    (*result)[i0] = val;
2170034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
2180034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return result;
2190034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
2200034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2210034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
2220034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd(
2230034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
2240034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2250034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2260034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
2270034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride,
2280034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi                               padding);
2290034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
2300034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2316bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
2326bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const Array2D<float>& operand, float init,
2336bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2346bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2356bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> dim_lengths{operand.height(), operand.width()};
2366bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
2376bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2386bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
2396bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
2406bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
2416bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    window_counts[i] =
2426bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2436bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    pad_low[i] = padding_both[i].first;
2446bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
2456bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
2466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  // Do a full 2D reduce window.
2486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i0_base = i0 * stride[0] - pad_low[0];
2516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i1_base = i1 * stride[1] - pad_low[1];
2526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      float val = init;
2546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i0_base + i0_win < operand.n1() &&
2586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i1_base + i1_win < operand.n2()) {
2596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi            val += operand(i0_base + i0_win, i1_base + i1_win);
2606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          }
2616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        }
2626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      }
2636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      (*result)(i0, i1) = val;
2646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    }
2656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
2666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  return result;
2676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi}
2686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2692d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>>
2702d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric(
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, float init,
2722d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const std::function<float(float, float)>& reduce_func,
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
277f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  return ReduceWindow4DGeneric(
278f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      operand, init, reduce_func, window, stride,
279f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      xla::MakePadding(dim_lengths, window, stride, padding));
280f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune}
281f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune
282f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>>
283f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric(
284f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const Array4D<float>& operand, float init,
285f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const std::function<float(float, float)>& reduce_func,
286f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& window,
287f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& stride,
288f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
289f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
290f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune                                 operand.n4()};
2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
295f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
297f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune        window_util::StridedBound(padded_width, window[i], stride[i]);
298f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    pad_low[i] = padding[i].first;
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           window_counts[2], window_counts[3]);
3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D reduce window.
3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = init;
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
3232d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                    val = reduce_func(
3242d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                        val, operand(i0_base + i0_win, i1_base + i1_win,
3252d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                                     i2_base + i2_win, i3_base + i3_win));
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(i0, i1, i2, i3) = val;
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3392d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
3402d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const Array4D<float>& operand, float init,
3412d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
3422d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
3432d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
3442d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
3452d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                               padding);
3462d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi}
3472d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi
3481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
3491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& input, const Array4D<float>& mean,
3501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& var, const Array4D<float>& scale,
3511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& offset, float epsilon) {
3521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto normalized =
3531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(input, mean, [](float a, float b) { return a - b; });
3541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized = *MapArray4D(normalized, var, [&](float a, float b) {
3551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    return a / std::sqrt(b + epsilon);
3561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  });
3571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized =
3581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
3591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
3601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
3611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>>
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus(
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, const Array4D<float>& source, float init,
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           operand.n3(), operand.n4());
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Fill the output, with the initial value.
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[0], source.n1());
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[1], source.n2());
3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[2], source.n3());
3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[3], source.n4());
3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D select and Scatter.
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          // Now we are inside a window and need to find the max and the argmax.
3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        i2_base + i2_win, i3_base + i3_win);
4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    if (tmp >= val) {
4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      val = tmp;
4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_0 = i0_base + i0_win;
4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_1 = i1_base + i1_win;
4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_2 = i2_base + i2_win;
4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_3 = i3_base + i3_win;
4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    }
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              source(i0, i1, i2, i3);
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions(
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dimension_numbers) {
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
442f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             {1, 1}, {1, 1},
443f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             std::move(dimension_numbers));
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated(
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dnums) {
452253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
453253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
454253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
455253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
456253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_kernel_strides;
457253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_input_dimensions;
458253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_kernel_dimensions;
459253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
460253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[0] = kernel_stride.second;
461253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[1] = kernel_stride.first;
462253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  } else {
463253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[0] = kernel_stride.first;
464253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[1] = kernel_stride.second;
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
467253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_input_dimensions[0] =
468253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
469253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_input_dimensions[1] =
470253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
471253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_kernel_dimensions[0] =
472253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
473253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_kernel_dimensions[1] =
474253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
475253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
476253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::vector<std::pair<int64, int64>> paddings =
477253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
478253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                  ordered_kernel_strides, padding);
479253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  CHECK_EQ(paddings.size(), 2);
480253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
481253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  Window window;
482253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
483253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  WindowDimension dim;
484253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_size(
485253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
486253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_stride(kernel_stride.first);
487253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_padding_low(paddings[0].first);
488253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_padding_high(paddings[0].second);
489253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_window_dilation(rhs_dilation.first);
490253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_base_dilation(lhs_dilation.first);
491253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  *window.add_dimensions() = dim;
492253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
493253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  WindowDimension dim2;
494253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_size(
495253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
496253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_stride(kernel_stride.second);
497253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_padding_low(paddings[1].first);
498253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_padding_high(paddings[1].second);
499253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_window_dilation(rhs_dilation.second);
500253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_base_dilation(lhs_dilation.second);
501253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  *window.add_dimensions() = dim2;
502253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
503253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  const Shape& shape =
504253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      ShapeInference::InferConvolveShape(lhs_literal->shape(),
505253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                         rhs_literal->shape(), window, dnums)
506253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu          .ConsumeValueOrDie();
507253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
508253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloInstruction* lhs_instruction =
509253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
510253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloInstruction* rhs_instruction =
511253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
512253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
513253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  b.AddInstruction(HloInstruction::CreateConvolve(
514253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      shape, lhs_instruction, rhs_instruction, window, dnums));
515b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan  HloModule module("ReferenceUtil");
516b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan  auto computation = module.AddEntryComputation(b.Build());
517253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
518253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloEvaluator evaluator;
519253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::unique_ptr<Literal> result_literal =
520b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan      evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
521253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
522253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
523b01346de8b5893a09d50ff4d9c80ca442a327a76Kay Zhu  auto result =
524253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
525253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(1),
526253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(2),
527253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(3));
528253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
529253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
530253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    *value = result_literal->Get<float>(indices);
531253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  });
5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D(
5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
5397e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(i, j));
5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D(
5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
5567e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < cols; ++i) {
5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < rows; ++j) {
5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(j, i));
5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& array, float init,
5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
5737e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> result;
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 3);
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const std::set<int64> dim_set(dims.begin(), dims.end());
5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dim_set.size(), 3);
5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins         ++a1) {
5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           ++a2) {
5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins             ++a3) {
5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float accumulator = init;
5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               ++i0) {
5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 ++i1) {
5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2 = 0;
5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3 = 0;
5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  accumulator = reduce_function(
5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          result.push_back(accumulator);
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
6091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const std::vector<float>& array, const std::vector<int64>& bounds,
6101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    int64 broadcast_from_dim) {
6111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto result =
6121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
6131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < result->n1(); ++i) {
6141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    for (int64 j = 0; j < result->n2(); ++j) {
6151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      for (int64 k = 0; k < result->n3(); ++k) {
6161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        for (int64 l = 0; l < result->n4(); ++l) {
6171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          switch (broadcast_from_dim) {
6181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 0:
6191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[i];
6201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 1:
6221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[j];
6231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 2:
6251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[k];
6261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 3:
6281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[l];
6291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            default:
6311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          }
6331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        }
6341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      }
6351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
6361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
6371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return result;
6381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array3D<float>& array, float init,
6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
6437e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 1);
6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = dims[0] == 2 ? array.n2() : array.n3();
6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i0 = 0; i0 < array.n1(); ++i0) {
6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i1 = 0; i1 < array.n2(); ++i1) {
6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int i2 = 0; i2 < array.n3(); ++i2) {
6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 row = dims[0] == 0 ? i1 : i0;
6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 col = dims[0] == 2 ? i1 : i2;
6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (*result)(row, col) =
6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            reduce_function((*result)(row, col), array(i0, i1, i2));
6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float)>& map_function) {
6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j));
6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
6771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs,
6781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, float)>& map_function) {
6791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.height(), rhs.height());
6801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.width());
6811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = lhs.height();
6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = rhs.width();
6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, int64, int64)>& map_function) {
6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
6961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
6971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
6981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
6991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j), i, j);
7011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
7071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand, const PaddingConfig& padding,
7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const float pad) {
7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in0 = operand.n1();
7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding0 = padding.dimensions(0).edge_padding_high();
7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding0 = padding.dimensions(0).edge_padding_low();
7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding0 = padding.dimensions(0).interior_padding();
7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out0 =
7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
7155aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 in1 = operand.n2();
7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 high_padding1 = padding.dimensions(1).edge_padding_high();
7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 low_padding1 = padding.dimensions(1).edge_padding_low();
7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 interior_padding1 = padding.dimensions(1).interior_padding();
7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 out1 =
7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
7225aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan
7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(out0, out1);
7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(pad);
7255aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  int64 o0 = low_padding0;
7265aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan  for (int64 i0 = 0; i0 < in0; ++i0) {
7275aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    int64 o1 = low_padding1;
7285aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    for (int64 i1 = 0; i1 < in1; ++i1) {
7295aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
7305aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan        (*result)(o0, o1) = operand(i0, i1);
7315aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      }
7325aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan      o1 += interior_padding1 + 1;
7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7345aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan    o0 += interior_padding0 + 1;
7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7399b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/* static */ Array3D<float> ReferenceUtil::PadArray3D(
7409b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const Array3D<float>& operand, const PaddingConfig& padding,
7419b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const float pad) {
7429b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  CHECK_EQ(padding.dimensions_size(), 3);
7439b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
7449b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
7459b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                                           operand.n3()};
7469b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::vector<int64> pad_low(3);
7479b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::vector<int64> pad_high(3);
7489b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::vector<int64> pad_interior(3);
7499b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::vector<int64> output_bounds(3);
7509b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  for (int64 i = 0; i < 3; ++i) {
7519b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    pad_low[i] = padding.dimensions(i).edge_padding_low();
7529b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    pad_high[i] = padding.dimensions(i).edge_padding_high();
7539b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    CHECK_LE(0, pad_low[i]);
7549b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    CHECK_LE(0, pad_high[i]);
7559b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented";
7569b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    pad_interior[i] = padding.dimensions(i).interior_padding();
7579b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
7589b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
7599b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                       (input_bounds[i] - 1) * pad_interior[i];
7609b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  }
7619b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
7629b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  Array3D<float> result(output_bounds[0], output_bounds[1], output_bounds[2]);
7639b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::vector<int> indices = {0, 0, 0};
7649b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
7659b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
7669b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
7679b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        float* value = &result(indices[0], indices[1], indices[2]);
7689b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        bool value_padded = false;
7699b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        for (int i = 0; i < 3; ++i) {
7709b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          bool in_low_padding = indices[i] < pad_low[i];
7719b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
7729b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          if (in_low_padding || in_high_padding) {
7739b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower            *value = pad;
7749b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower            value_padded = true;
7759b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          }
7769b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          if (pad_interior[i] &&
7779b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower              (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
7789b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower            *value = pad;
7799b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower            value_padded = true;
7809b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          }
7819b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        }
7829b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        if (value_padded) {
7839b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower          continue;
7849b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        }
7859b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
7869b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                         (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
7879b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                         (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
7889b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      }
7899b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    }
7909b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  }
7919b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  return result;
7929b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower}
7939b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
794c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune/* static */ Array4D<float> ReferenceUtil::PadArray4D(
795c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    const Array4D<float>& operand, const PaddingConfig& padding,
796c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    const float pad) {
797c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  CHECK_EQ(padding.dimensions_size(), 4);
798c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
799c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
800c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune                                           operand.n3(), operand.n4()};
801c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> pad_low(4);
802c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> pad_high(4);
8038ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower  std::vector<int64> pad_interior(4);
804c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  std::vector<int64> output_bounds(4);
805c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  for (int64 i = 0; i < 4; ++i) {
806c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    pad_low[i] = padding.dimensions(i).edge_padding_low();
807c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    pad_high[i] = padding.dimensions(i).edge_padding_high();
8088ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented";
8098ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    pad_interior[i] = padding.dimensions(i).interior_padding();
810c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
8118ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
8128ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                       (input_bounds[i] - 1) * pad_interior[i];
813c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  }
814c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
815c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
816c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune                        output_bounds[3]);
817c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
818c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    for (int i = 0; i < 4; ++i) {
819c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      bool in_low_padding = indices[i] < pad_low[i];
820c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
821c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      if (in_low_padding || in_high_padding) {
822c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune        *value = pad;
823c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune        return;
824c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune      }
8258ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower      if (pad_interior[i] &&
8268ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower          (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
8278ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower        *value = pad;
8288ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower        return;
8298ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower      }
830c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune    }
8318ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower    *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
8328ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
8338ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
8348ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower                     (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
835c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  });
836c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune  return result;
837c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune}
838c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune
8391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
840