11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array> 19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 23253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_evaluator.h" 24253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_instruction.h" 25253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/shape_inference.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D( 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand) { 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height()); 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 w = 0; w < operand.width(); ++w) { 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 h = 0; h < operand.height(); ++h) { 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(w, h) = operand(h, w); 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D( 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs) { 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(m, n); 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF32( 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D( 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<double>& lhs, const Array2D<double>& rhs) { 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(m, n); 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF64( 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64( 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& input) { 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(input.height(), input.width()); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 rowno = 0; rowno < input.height(); ++rowno) { 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 colno = 0; colno < input.height(); ++colno) { 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(rowno, colno) = input(rowno, colno); 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 929b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D( 939b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 949b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Padding padding) { 959b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower return ConvArray3DGeneralDimensionsDilated( 969b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower lhs, rhs, kernel_stride, padding, 1, 1, 979b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); 989b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower} 999b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1009b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*static*/ std::unique_ptr<Array3D<float>> 1019b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlowerReferenceUtil::ConvArray3DGeneralDimensionsDilated( 1029b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 1039b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Padding padding, int64 lhs_dilation, int64 rhs_dilation, 1049b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const ConvolutionDimensionNumbers& dnums) { 105102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer CHECK_EQ(dnums.input_spatial_dimensions_size(), 1); 106102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1); 107102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer CHECK_EQ(dnums.output_spatial_dimensions_size(), 1); 1089b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // Reuse the code for Array4D-convolution by extending the 3D input into a 4D 1099b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // array by adding a fourth dummy dimension of size 1 without stride, padding 1109b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // and dilation. 1119b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); 1129b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4dlhs.Each( 1139b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1149b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1159b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); 1169b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1179b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); 1189b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4drhs.Each( 1199b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1209b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1219b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); 1229b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1239b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // Add a second dummy spatial dimensions. 1249b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower ConvolutionDimensionNumbers dnums2d = dnums; 125102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer dnums2d.add_input_spatial_dimensions(3); 1269b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower dnums2d.add_kernel_spatial_dimensions(3); 127102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer dnums2d.add_output_spatial_dimensions(3); 1289b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated( 1299b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, 1309b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower {rhs_dilation, 1}, dnums2d); 1319b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1329b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(), 1339b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr4->height()); 1349b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr4->Each( 1359b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1369b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1379b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; 1389b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1399b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower return convr3; 1409b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower} 1419b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D( 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding) { 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensions( 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs, rhs, kernel_stride, padding, 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder::CreateDefaultConvDimensionNumbers()); 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1507fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>> 1517fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, 1527fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& depthwise_weights, 1537fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& pointwise_weights, 1547fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee std::pair<int64, int64> kernel_stride, 1557fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Padding padding) { 1567fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const int64 depth_multiplier = depthwise_weights.planes(); 1577fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier); 1587fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1597fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // Combine the two weights by reducing the depth_multiplier, so that we can 1607fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // apply a single convolution on the combined weights. 1617fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Array4D<float> weights(pointwise_weights.planes(), input.depth(), 1627fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights.height(), depthwise_weights.width()); 1637fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) { 1647fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) { 1657fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kz = 0; kz < input.depth(); ++kz) { 1667fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 out = 0; out < pointwise_weights.planes(); ++out) { 1677fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee float weight = 0.0; 1687fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 depth = 0; depth < depth_multiplier; ++depth) { 1697fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weight += 1707fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights(depth, kz, ky, kx) * 1717fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee pointwise_weights(out, depth + kz * depth_multiplier, 0, 0); 1727fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1737fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weights(out, kz, ky, kx) = weight; 1747fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1757fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1767fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1777fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1787fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1797fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee return ConvArray4D(input, weights, kernel_stride, padding); 1807fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee} 1817fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 window_len, int64 stride, 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding) { 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kValid) { 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return window_util::StridedBound(unpadded_width, window_len, stride); 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1910034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 1920034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric( 1930034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 1940034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const std::function<float(float, float)>& reduce_func, 1950034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1960034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1970034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 198e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi return ReduceWindow1DGeneric( 199e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi operand, init, reduce_func, window, stride, 200e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi xla::MakePadding(dim_lengths, window, stride, padding)); 201e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi} 2020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 203e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 204e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric( 205e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 206e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi const std::function<float(float, float)>& reduce_func, 207e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 208e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, 209e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 210e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 2110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 2120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 2130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 214e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 2150034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi window_counts[i] = 216e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi window_util::StridedBound(padded_width, window[i], stride[i]); 217e2f9107effb0c5c4cee49a71562865d9e919b3d0Tayo Oguntebi pad_low[i] = padding[i].first; 2180034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2190034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi auto result = MakeUnique<std::vector<float>>(window_counts[0]); 2200034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2210034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi // Do a full 1D reduce window. 2220034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2230034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 2240034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2250034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi float val = init; 2260034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2270034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { 2280034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi val = reduce_func(val, operand[i0_base + i0_win]); 2290034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2300034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2310034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi (*result)[i0] = val; 2320034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2330034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return result; 2340034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 2350034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2360034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 2370034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd( 2380034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 2390034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2400034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2410034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 2420034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, 2430034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi padding); 2440034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 2450034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( 2476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const Array2D<float>& operand, float init, 2486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> dim_lengths{operand.height(), operand.width()}; 2516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 2526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 2546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 2556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 2566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi window_counts[i] = 2576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 2586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi pad_low[i] = padding_both[i].first; 2596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]); 2616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi // Do a full 2D reduce window. 2636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 2656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 2666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i1_base = i1 * stride[1] - pad_low[1]; 2676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi float val = init; 2696bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2706bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2716bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2726bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i0_base + i0_win < operand.n1() && 2736bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i1_base + i1_win < operand.n2()) { 2746bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi val += operand(i0_base + i0_win, i1_base + i1_win); 2756bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2766bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2776bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2786bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi (*result)(i0, i1) = val; 2796bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2806bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2816bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi return result; 2826bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi} 2836bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2847699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd( 2857699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi const Array3D<float>& operand, float init, 2867699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2877699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2887699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()}; 2897699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 2907699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi 2917699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 2927699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 2937699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 2947699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi window_counts[i] = 2957699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 2967699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi pad_low[i] = padding_both[i].first; 2977699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 2987699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1], 2997699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi window_counts[2]); 3007699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi 3017699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 3027699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 3037699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 3047699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 3057699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi int64 i1_base = i1 * stride[1] - pad_low[1]; 3067699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi int64 i2_base = i2 * stride[2] - pad_low[2]; 3077699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi 3087699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi float val = init; 3097699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 3107699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 3117699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 3127699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 3137699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() && 3147699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi i1_base + i1_win < operand.n2() && 3157699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi i2_base + i2_win < operand.n3()) { 3167699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi val += operand(i0_base + i0_win, i1_base + i1_win, 3177699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi i2_base + i2_win); 3187699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3197699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3207699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3217699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3227699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi (*result)(i0, i1, i2) = val; 3237699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3247699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3257699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi } 3267699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi return result; 3277699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi} 3287699ea8bee2ee9d0d5c93706cf1357a36c60ef58Tayo Oguntebi 3292d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> 3302d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric( 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, float init, 3322d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const std::function<float(float, float)>& reduce_func, 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 337f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune return ReduceWindow4DGeneric( 338f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand, init, reduce_func, window, stride, 339f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune xla::MakePadding(dim_lengths, window, stride, padding)); 340f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune} 341f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune 342f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>> 343f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric( 344f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const Array4D<float>& operand, float init, 345f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const std::function<float(float, float)>& reduce_func, 346f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& window, 347f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& stride, 348f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 349f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 350f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand.n4()}; 3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 355f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 357f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune window_util::StridedBound(padded_width, window[i], stride[i]); 358f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune pad_low[i] = padding[i].first; 3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[2], window_counts[3]); 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D reduce window. 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = init; 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 3832d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val = reduce_func( 3842d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val, operand(i0_base + i0_win, i1_base + i1_win, 3852d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi i2_base + i2_win, i3_base + i3_win)); 3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i0, i1, i2, i3) = val; 3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3992d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( 4002d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const Array4D<float>& operand, float init, 4012d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 4022d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 4032d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 4042d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, 4052d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi padding); 4062d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi} 4072d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi 4081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D( 4091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& input, const Array4D<float>& mean, 4101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& var, const Array4D<float>& scale, 4111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& offset, float epsilon) { 4121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto normalized = 4131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(input, mean, [](float a, float b) { return a - b; }); 4141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = *MapArray4D(normalized, var, [&](float a, float b) { 4151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return a / std::sqrt(b + epsilon); 4161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower }); 4171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = 4181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); 4191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); 4201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 4211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus( 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, const Array4D<float>& source, float init, 4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) { 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding = same_padding ? Padding::kSame : Padding::kValid; 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(), 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n3(), operand.n4()); 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Fill the output, with the initial value. 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WindowCount(dim_lengths[i], window[i], stride[i], padding); 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins pad_low[i] = padding_both[i].first; 4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[0], source.n1()); 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[1], source.n2()); 4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[2], source.n3()); 4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[3], source.n4()); 4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D select and Scatter. 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Now we are inside a window and need to find the max and the argmax. 4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; 4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; 4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; 4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; 4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); 4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 4711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float tmp = operand(i0_base + i0_win, i1_base + i1_win, 4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win, i3_base + i3_win); 4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (tmp >= val) { 4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins val = tmp; 4771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_0 = i0_base + i0_win; 4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_1 = i1_base + i1_win; 4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_2 = i2_base + i2_win; 4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_3 = i3_base + i3_win; 4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins source(i0, i1, i2, i3); 4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions( 4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dimension_numbers) { 5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, 502f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower {1, 1}, {1, 1}, 503f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower std::move(dimension_numbers)); 5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated( 5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, 5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dnums) { 512253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); 513253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs); 514253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs); 515253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 516253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_kernel_strides; 517253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_input_dimensions; 518253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_kernel_dimensions; 519253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) { 520253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[0] = kernel_stride.second; 521253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[1] = kernel_stride.first; 522253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu } else { 523253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[0] = kernel_stride.first; 524253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[1] = kernel_stride.second; 5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 527253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_input_dimensions[0] = 528102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); 529253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_input_dimensions[1] = 530102bfdfd830f4dab6e00371e63a82561e1246518David Majnemer lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); 531253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_dimensions[0] = 532253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); 533253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_dimensions[1] = 534253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); 535253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 536253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::vector<std::pair<int64, int64>> paddings = 537253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, 538253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides, padding); 539253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu CHECK_EQ(paddings.size(), 2); 540253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 541253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu Window window; 542253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 543253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu WindowDimension dim; 544253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_size( 545253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); 546253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_stride(kernel_stride.first); 547253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_padding_low(paddings[0].first); 548253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_padding_high(paddings[0].second); 549253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_window_dilation(rhs_dilation.first); 550253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_base_dilation(lhs_dilation.first); 551253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *window.add_dimensions() = dim; 552253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 553253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu WindowDimension dim2; 554253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_size( 555253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); 556253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_stride(kernel_stride.second); 557253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_padding_low(paddings[1].first); 558253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_padding_high(paddings[1].second); 559253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_window_dilation(rhs_dilation.second); 560253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_base_dilation(lhs_dilation.second); 561253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *window.add_dimensions() = dim2; 562253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 563253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu const Shape& shape = 564253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ShapeInference::InferConvolveShape(lhs_literal->shape(), 565253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape(), window, dnums) 566253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu .ConsumeValueOrDie(); 567253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 568253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloInstruction* lhs_instruction = 569253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); 570253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloInstruction* rhs_instruction = 571253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); 572253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 573253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConvolve( 574253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu shape, lhs_instruction, rhs_instruction, window, dnums)); 575b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan HloModule module("ReferenceUtil"); 576b1c10555afe9ad4ebebbd83eb31dbf8006d7980bMark Heffernan auto computation = module.AddEntryComputation(b.Build()); 577253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 578253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloEvaluator evaluator; 579253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::unique_ptr<Literal> result_literal = 580713d45278491d792c525344de6038a61ebcb2136Kay Zhu evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie(); 581253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 582253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); 583b01346de8b5893a09d50ff4d9c80ca442a327a76Kay Zhu auto result = 584253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0), 585253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(1), 586253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(2), 587253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(3)); 588253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 589253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 590253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *value = result_literal->Get<float>(indices); 591253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu }); 5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D( 5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 5997e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(i, j)); 6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D( 6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 6167e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < cols; ++i) { 6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < rows; ++j) { 6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(j, i)); 6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 6261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D( 6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& array, float init, 6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 6337e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> result; 6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 3); 6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::set<int64> dim_set(dims.begin(), dims.end()); 6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dim_set.size(), 3); 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a1) { 6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); 6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a2) { 6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); 6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a3) { 6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float accumulator = init; 6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); 6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i0) { 6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); 6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i1) { 6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; 6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { 6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; 6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { 6541654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // Handle zero-sized arrays. 6551654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && 6561654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar array.n4() > 0) { 6571654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar accumulator = reduce_function( 6581654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); 6591654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } 6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result.push_back(accumulator); 6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D( 6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const std::vector<float>& array, const std::vector<int64>& bounds, 6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 broadcast_from_dim) { 6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto result = 6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]); 6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < result->n1(); ++i) { 6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 j = 0; j < result->n2(); ++j) { 6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 k = 0; k < result->n3(); ++k) { 6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 l = 0; l < result->n4(); ++l) { 6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower switch (broadcast_from_dim) { 6821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 0: 6831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[i]; 6841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 1: 6861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[j]; 6871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 2: 6891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[k]; 6901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 3: 6921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[l]; 6931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower default: 6951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return result; 7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D( 7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array3D<float>& array, float init, 7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 7077e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 1); 7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = dims[0] == 0 ? array.n2() : array.n1(); 7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = dims[0] == 2 ? array.n2() : array.n3(); 7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i0 = 0; i0 < array.n1(); ++i0) { 7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i1 = 0; i1 < array.n2(); ++i1) { 7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i2 = 0; i2 < array.n3(); ++i2) { 7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 row = dims[0] == 0 ? i1 : i0; 7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 col = dims[0] == 2 ? i1 : i2; 7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(row, col) = 7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins reduce_function((*result)(row, col), array(i0, i1, i2)); 7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float)>& map_function) { 7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j)); 7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 7411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs, 7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, float)>& map_function) { 7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.height(), rhs.height()); 7441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.width()); 7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = lhs.height(); 7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = rhs.width(); 7471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); 7511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D( 7571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, int64, int64)>& map_function) { 7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 7601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j), i, j); 7651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 771