11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array>
19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
23253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
24253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_instruction.h"
25253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/shape_inference.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& operand) {
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 w = 0; w < operand.width(); ++w) {
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 h = 0; h < operand.height(); ++h) {
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(w, h) = operand(h, w);
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs) {
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(m, n);
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<double>& lhs, const Array2D<double>& rhs) {
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.height());
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int m = lhs.height();
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int n = rhs.width();
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int k = lhs.width();
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(m, n);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Because Eigen is a header-oriented library, make sure that the Eigen code
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // is the same as the code used by the CPU backend (otherwise the linker will
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // randomly pick *some* definition).
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      k,
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_lhs=*/0,
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*transpose_rhs=*/0);
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& input) {
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 rowno = 0; rowno < input.height(); ++rowno) {
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 colno = 0; colno < input.height(); ++colno) {
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(rowno, colno) = input(rowno, colno);
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
929b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
939b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
949b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    Padding padding) {
959b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  return ConvArray3DGeneralDimensionsDilated(
969b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      lhs, rhs, kernel_stride, padding, 1, 1,
979b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      ComputationBuilder::CreateDefaultConvDimensionNumbers(1));
989b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower}
999b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1009b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*static*/ std::unique_ptr<Array3D<float>>
1019b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlowerReferenceUtil::ConvArray3DGeneralDimensionsDilated(
1029b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
1039b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    Padding padding, int64 lhs_dilation, int64 rhs_dilation,
1049b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower    const ConvolutionDimensionNumbers& dnums) {
105102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer  CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
106102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer  CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
107102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer  CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
1089b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
1099b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // array by adding a fourth dummy dimension of size 1 without stride, padding
1109b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // and dilation.
1119b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
1129b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  a4dlhs.Each(
1139b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1149b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1159b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
1169b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1179b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
1189b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  a4drhs.Each(
1199b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1209b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1219b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
1229b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1239b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  // Add a second dummy spatial dimensions.
1249b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  ConvolutionDimensionNumbers dnums2d = dnums;
125102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer  dnums2d.add_input_spatial_dimensions(3);
1269b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  dnums2d.add_kernel_spatial_dimensions(3);
127102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer  dnums2d.add_output_spatial_dimensions(3);
1289b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
1299b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
1309b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      {rhs_dilation, 1}, dnums2d);
1319b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1329b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(),
1339b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower                                           convr4->height());
1349b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  convr4->Each(
1359b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
1369b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        CHECK_EQ(indices[3], 0);
1379b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower        convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
1389b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower      });
1399b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower  return convr3;
1409b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower}
1419b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding) {
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensions(
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      lhs, rhs, kernel_stride, padding,
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ComputationBuilder::CreateDefaultConvDimensionNumbers());
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1507fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>>
1517fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
1527fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& depthwise_weights,
1537fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    const Array4D<float>& pointwise_weights,
1547fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    std::pair<int64, int64> kernel_stride,
1557fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                                    Padding padding) {
1567fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  const int64 depth_multiplier = depthwise_weights.planes();
1577fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
1587fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1597fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // Combine the two weights by reducing the depth_multiplier, so that we can
1607fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  // apply a single convolution on the combined weights.
1617fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
1627fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                         depthwise_weights.height(), depthwise_weights.width());
1637fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
1647fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
1657fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      for (int64 kz = 0; kz < input.depth(); ++kz) {
1667fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
1677fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          float weight = 0.0;
1687fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
1697fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee            weight +=
1707fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                depthwise_weights(depth, kz, ky, kx) *
1717fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
1727fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          }
1737fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee          weights(out, kz, ky, kx) = weight;
1747fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee        }
1757fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee      }
1767fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee    }
1777fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  }
1787fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1797fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee  return ConvArray4D(input, weights, kernel_stride, padding);
1807fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee}
1817fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              int64 window_len, int64 stride,
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                              Padding padding) {
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == Padding::kValid) {
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return window_util::StridedBound(unpadded_width, window_len, stride);
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1910034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
1920034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric(
1930034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
1940034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const std::function<float(float, float)>& reduce_func,
1950034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
1960034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
1970034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
198e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi  return ReduceWindow1DGeneric(
199e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi      operand, init, reduce_func, window, stride,
200e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi      xla::MakePadding(dim_lengths, window, stride, padding));
201e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi}
2020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
203e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
204e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric(
205e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
206e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    const std::function<float(float, float)>& reduce_func,
207e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
208e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride,
209e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
210e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi  std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
2110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
2120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
2130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
214e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
2150034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    window_counts[i] =
216e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi        window_util::StridedBound(padded_width, window[i], stride[i]);
217e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi    pad_low[i] = padding[i].first;
2180034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
2190034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  auto result = MakeUnique<std::vector<float>>(window_counts[0]);
2200034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2210034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  // Do a full 1D reduce window.
2220034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2230034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    int64 i0_base = i0 * stride[0] - pad_low[0];
2240034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2250034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    float val = init;
2260034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2270034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
2280034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi        val = reduce_func(val, operand[i0_base + i0_win]);
2290034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi      }
2300034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    }
2310034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    (*result)[i0] = val;
2320034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  }
2330034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return result;
2340034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
2350034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2360034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static  */ std::unique_ptr<std::vector<float>>
2370034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd(
2380034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<float>& operand, float init,
2390034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2400034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2410034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
2420034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi  return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride,
2430034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi                               padding);
2440034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi}
2450034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi
2466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
2476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const Array2D<float>& operand, float init,
2486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> dim_lengths{operand.height(), operand.width()};
2516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
2526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
2546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
2556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
2566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    window_counts[i] =
2576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    pad_low[i] = padding_both[i].first;
2596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
2606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
2616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  // Do a full 2D reduce window.
2636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
2646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
2656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i0_base = i0 * stride[0] - pad_low[0];
2666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      int64 i1_base = i1 * stride[1] - pad_low[1];
2676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      float val = init;
2696bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
2706bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
2716bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
2726bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i0_base + i0_win < operand.n1() &&
2736bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi              i1_base + i1_win < operand.n2()) {
2746bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi            val += operand(i0_base + i0_win, i1_base + i1_win);
2756bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi          }
2766bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi        }
2776bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      }
2786bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi      (*result)(i0, i1) = val;
2796bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi    }
2806bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  }
2816bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi  return result;
2826bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi}
2836bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi
2847699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi/* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
2857699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    const Array3D<float>& operand, float init,
2867699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
2877699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
2887699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
2897699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
2907699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi
2917699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  std::vector<int64> window_counts(window.size(), 0);
2927699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  std::vector<int64> pad_low(window.size(), 0);
2937699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  for (int64 i = 0; i < window.size(); ++i) {
2947699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    window_counts[i] =
2957699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        WindowCount(dim_lengths[i], window[i], stride[i], padding);
2967699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    pad_low[i] = padding_both[i].first;
2977699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  }
2987699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1],
2997699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                                           window_counts[2]);
3007699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi
3017699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
3027699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
3037699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
3047699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        int64 i0_base = i0 * stride[0] - pad_low[0];
3057699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        int64 i1_base = i1 * stride[1] - pad_low[1];
3067699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        int64 i2_base = i2 * stride[2] - pad_low[2];
3077699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi
3087699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        float val = init;
3097699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
3107699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi          for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
3117699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi            for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
3127699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi              if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
3137699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                  i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
3147699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                  i1_base + i1_win < operand.n2() &&
3157699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                  i2_base + i2_win < operand.n3()) {
3167699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                val += operand(i0_base + i0_win, i1_base + i1_win,
3177699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi                               i2_base + i2_win);
3187699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi              }
3197699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi            }
3207699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi          }
3217699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        }
3227699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi        (*result)(i0, i1, i2) = val;
3237699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi      }
3247699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi    }
3257699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  }
3267699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi  return result;
3277699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi}
3287699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi
3292d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>>
3302d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric(
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, float init,
3322d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const std::function<float(float, float)>& reduce_func,
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
337f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  return ReduceWindow4DGeneric(
338f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      operand, init, reduce_func, window, stride,
339f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune      xla::MakePadding(dim_lengths, window, stride, padding));
340f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune}
341f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune
342f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>>
343f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric(
344f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const Array4D<float>& operand, float init,
345f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const std::function<float(float, float)>& reduce_func,
346f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& window,
347f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<int64>& stride,
348f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
349f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
350f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune                                 operand.n4()};
3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
355f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
357f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune        window_util::StridedBound(padded_width, window[i], stride[i]);
358f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune    pad_low[i] = padding[i].first;
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           window_counts[2], window_counts[3]);
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D reduce window.
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = init;
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
3832d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                    val = reduce_func(
3842d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                        val, operand(i0_base + i0_win, i1_base + i1_win,
3852d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                                     i2_base + i2_win, i3_base + i3_win));
3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(i0, i1, i2, i3) = val;
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3992d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
4002d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const Array4D<float>& operand, float init,
4012d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& window,
4022d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
4032d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
4042d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi  return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
4052d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi                               padding);
4062d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi}
4072d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi
4081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
4091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& input, const Array4D<float>& mean,
4101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& var, const Array4D<float>& scale,
4111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const Array4D<float>& offset, float epsilon) {
4121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto normalized =
4131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(input, mean, [](float a, float b) { return a - b; });
4141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized = *MapArray4D(normalized, var, [&](float a, float b) {
4151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    return a / std::sqrt(b + epsilon);
4161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  });
4171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  normalized =
4181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
4191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
4201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
4211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static  */ std::unique_ptr<Array4D<float>>
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus(
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& operand, const Array4D<float>& source, float init,
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& window,
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           operand.n3(), operand.n4());
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 operand.n4()};
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Fill the output, with the initial value.
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> window_counts(window.size(), 0);
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> pad_low(window.size(), 0);
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < window.size(); ++i) {
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    window_counts[i] =
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        WindowCount(dim_lengths[i], window[i], stride[i], padding);
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    pad_low[i] = padding_both[i].first;
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[0], source.n1());
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[1], source.n2());
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[2], source.n3());
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(window_counts[3], source.n4());
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Do a full 4D select and Scatter.
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          // Now we are inside a window and need to find the max and the argmax.
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i0_base = i0 * stride[0] - pad_low[0];
4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i1_base = i1 * stride[1] - pad_low[1];
4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i2_base = i2 * stride[2] - pad_low[2];
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 i3_base = i3 * stride[3] - pad_low[3];
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i0_base + i0_win < operand.n1() &&
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i1_base + i1_win < operand.n2() &&
4711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i2_base + i2_win < operand.n3() &&
4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      i3_base + i3_win < operand.n4()) {
4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        i2_base + i2_win, i3_base + i3_win);
4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    if (tmp >= val) {
4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      val = tmp;
4771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_0 = i0_base + i0_win;
4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_1 = i1_base + i1_win;
4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_2 = i2_base + i2_win;
4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      scatter_3 = i3_base + i3_win;
4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    }
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  }
4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              source(i0, i1, i2, i3);
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions(
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dimension_numbers) {
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
502f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             {1, 1}, {1, 1},
503f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower                                             std::move(dimension_numbers));
5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>>
5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated(
5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& lhs, const Array4D<float>& rhs,
5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> kernel_stride, Padding padding,
5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ConvolutionDimensionNumbers dnums) {
512253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
513253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
514253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
515253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
516253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_kernel_strides;
517253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_input_dimensions;
518253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::array<int64, 2> ordered_kernel_dimensions;
519253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
520253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[0] = kernel_stride.second;
521253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[1] = kernel_stride.first;
522253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  } else {
523253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[0] = kernel_stride.first;
524253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    ordered_kernel_strides[1] = kernel_stride.second;
5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
527253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_input_dimensions[0] =
528102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
529253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_input_dimensions[1] =
530102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
531253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_kernel_dimensions[0] =
532253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
533253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  ordered_kernel_dimensions[1] =
534253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
535253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
536253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::vector<std::pair<int64, int64>> paddings =
537253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
538253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                  ordered_kernel_strides, padding);
539253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  CHECK_EQ(paddings.size(), 2);
540253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
541253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  Window window;
542253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
543253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  WindowDimension dim;
544253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_size(
545253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
546253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_stride(kernel_stride.first);
547253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_padding_low(paddings[0].first);
548253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_padding_high(paddings[0].second);
549253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_window_dilation(rhs_dilation.first);
550253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim.set_base_dilation(lhs_dilation.first);
551253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  *window.add_dimensions() = dim;
552253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
553253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  WindowDimension dim2;
554253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_size(
555253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
556253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_stride(kernel_stride.second);
557253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_padding_low(paddings[1].first);
558253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_padding_high(paddings[1].second);
559253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_window_dilation(rhs_dilation.second);
560253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  dim2.set_base_dilation(lhs_dilation.second);
561253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  *window.add_dimensions() = dim2;
562253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
563253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  const Shape& shape =
564253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      ShapeInference::InferConvolveShape(lhs_literal->shape(),
565253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                         rhs_literal->shape(), window, dnums)
566253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu          .ConsumeValueOrDie();
567253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
568253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloInstruction* lhs_instruction =
569253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
570253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloInstruction* rhs_instruction =
571253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
572253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
573253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  b.AddInstruction(HloInstruction::CreateConvolve(
574253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      shape, lhs_instruction, rhs_instruction, window, dnums));
575b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan  HloModule module("ReferenceUtil");
576b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan  auto computation = module.AddEntryComputation(b.Build());
577253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
578253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  HloEvaluator evaluator;
579253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  std::unique_ptr<Literal> result_literal =
580713d45278491d792c525344de6038a61ebcb2136Kay Zhu      evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
581253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
582253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
583b01346de8b5893a09d50ff4d9c80ca442a327a76Kay Zhu  auto result =
584253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu      MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
585253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(1),
586253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(2),
587253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu                                 result_literal->shape().dimensions(3));
588253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu
589253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
590253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu    *value = result_literal->Get<float>(indices);
591253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu  });
5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D(
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
5997e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(i, j));
6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>>
6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D(
6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix, float init,
6167e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<std::vector<float>>();
6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < cols; ++i) {
6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    float acc = init;
6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < rows; ++j) {
6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      acc = reduce_function(acc, matrix(j, i));
6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    result->push_back(acc);
6261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array4D<float>& array, float init,
6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
6337e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> result;
6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 3);
6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const std::set<int64> dim_set(dims.begin(), dims.end());
6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dim_set.size(), 3);
6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins         ++a1) {
6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           ++a2) {
6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins             ++a3) {
6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          float accumulator = init;
6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               ++i0) {
6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 ++i1) {
6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              for (int64 i2 = 0;
6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                for (int64 i3 = 0;
6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
6541654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                  // Handle zero-sized arrays.
6551654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                  if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
6561654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                      array.n4() > 0) {
6571654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                    accumulator = reduce_function(
6581654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                        accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
6591654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar                  }
6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                }
6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              }
6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          result.push_back(accumulator);
6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    const std::vector<float>& array, const std::vector<int64>& bounds,
6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    int64 broadcast_from_dim) {
6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto result =
6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < result->n1(); ++i) {
6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    for (int64 j = 0; j < result->n2(); ++j) {
6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      for (int64 k = 0; k < result->n3(); ++k) {
6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        for (int64 l = 0; l < result->n4(); ++l) {
6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          switch (broadcast_from_dim) {
6821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 0:
6831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[i];
6841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 1:
6861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[j];
6871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 2:
6891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[k];
6901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            case 3:
6921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              (*result)(i, j, k, l) = array[l];
6931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower            default:
6951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower              break;
6961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower          }
6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower        }
6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      }
6991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  return result;
7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array3D<float>& array, float init,
7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    tensorflow::gtl::ArraySlice<int64> dims,
7077e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky    const std::function<float(float, float)>& reduce_function) {
7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(dims.size(), 1);
7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = dims[0] == 2 ? array.n2() : array.n3();
7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  result->Fill(init);
7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i0 = 0; i0 < array.n1(); ++i0) {
7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i1 = 0; i1 < array.n2(); ++i1) {
7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int i2 = 0; i2 < array.n3(); ++i2) {
7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 row = dims[0] == 0 ? i1 : i0;
7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        int64 col = dims[0] == 2 ? i1 : i2;
7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (*result)(row, col) =
7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            reduce_function((*result)(row, col), array(i0, i1, i2));
7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float)>& map_function) {
7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j));
7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
7411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& lhs, const Array2D<float>& rhs,
7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, float)>& map_function) {
7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.height(), rhs.height());
7441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK_EQ(lhs.width(), rhs.width());
7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = lhs.height();
7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = rhs.width();
7471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
7511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
7571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const Array2D<float>& matrix,
7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::function<float(float, int64, int64)>& map_function) {
7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 rows = matrix.height();
7601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 cols = matrix.width();
7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result = MakeUnique<Array2D<float>>(rows, cols);
7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int64 i = 0; i < rows; ++i) {
7631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int64 j = 0; j < cols; ++j) {
7641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      (*result)(i, j) = map_function(matrix(i, j), i, j);
7651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
7661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
7671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return result;
7681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
771