reference_util.cc revision 253bcbb71bdd1f9f2609b085dce90fe9b31cbd5a
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array> 19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 23253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_evaluator.h" 24253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/hlo_instruction.h" 25253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu#include "tensorflow/compiler/xla/service/shape_inference.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D( 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand) { 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height()); 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 w = 0; w < operand.width(); ++w) { 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 h = 0; h < operand.height(); ++h) { 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(w, h) = operand(h, w); 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D( 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs) { 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(m, n); 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF32( 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D( 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<double>& lhs, const Array2D<double>& rhs) { 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(m, n); 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF64( 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64( 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& input) { 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(input.height(), input.width()); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 rowno = 0; rowno < input.height(); ++rowno) { 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 colno = 0; colno < input.height(); ++colno) { 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(rowno, colno) = input(rowno, colno); 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 929b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D( 939b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 949b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Padding padding) { 959b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower return ConvArray3DGeneralDimensionsDilated( 969b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower lhs, rhs, kernel_stride, padding, 1, 1, 979b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); 989b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower} 999b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1009b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/*static*/ std::unique_ptr<Array3D<float>> 1019b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlowerReferenceUtil::ConvArray3DGeneralDimensionsDilated( 1029b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 1039b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Padding padding, int64 lhs_dilation, int64 rhs_dilation, 1049b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const ConvolutionDimensionNumbers& dnums) { 1059b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(dnums.spatial_dimensions_size(), 1); 1069b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // Reuse the code for Array4D-convolution by extending the 3D input into a 4D 1079b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // array by adding a fourth dummy dimension of size 1 without stride, padding 1089b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // and dilation. 1099b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); 1109b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4dlhs.Each( 1119b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1129b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1139b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); 1149b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1159b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); 1169b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4drhs.Each( 1179b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1189b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1199b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); 1209b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1219b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower // Add a second dummy spatial dimensions. 1229b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower ConvolutionDimensionNumbers dnums2d = dnums; 1239b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower dnums2d.add_spatial_dimensions(3); 1249b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower dnums2d.add_kernel_spatial_dimensions(3); 1259b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated( 1269b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, 1279b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower {rhs_dilation, 1}, dnums2d); 1289b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1299b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(), 1309b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr4->height()); 1319b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr4->Each( 1329b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 1339b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(indices[3], 0); 1349b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; 1359b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower }); 1369b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower return convr3; 1379b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower} 1389b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D( 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding) { 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensions( 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs, rhs, kernel_stride, padding, 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder::CreateDefaultConvDimensionNumbers()); 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1477fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>> 1487fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, 1497fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& depthwise_weights, 1507fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& pointwise_weights, 1517fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee std::pair<int64, int64> kernel_stride, 1527fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Padding padding) { 1537fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const int64 depth_multiplier = depthwise_weights.planes(); 1547fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier); 1557fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1567fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // Combine the two weights by reducing the depth_multiplier, so that we can 1577fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // apply a single convolution on the combined weights. 1587fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Array4D<float> weights(pointwise_weights.planes(), input.depth(), 1597fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights.height(), depthwise_weights.width()); 1607fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) { 1617fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) { 1627fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kz = 0; kz < input.depth(); ++kz) { 1637fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 out = 0; out < pointwise_weights.planes(); ++out) { 1647fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee float weight = 0.0; 1657fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 depth = 0; depth < depth_multiplier; ++depth) { 1667fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weight += 1677fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights(depth, kz, ky, kx) * 1687fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee pointwise_weights(out, depth + kz * depth_multiplier, 0, 0); 1697fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1707fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weights(out, kz, ky, kx) = weight; 1717fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1727fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1737fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1747fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1757fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1767fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee return ConvArray4D(input, weights, kernel_stride, padding); 1777fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee} 1787fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 window_len, int64 stride, 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding) { 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kValid) { 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return window_util::StridedBound(unpadded_width, window_len, stride); 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1880034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 1890034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric( 1900034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 1910034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const std::function<float(float, float)>& reduce_func, 1920034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1930034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1940034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 1950034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 1960034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1970034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 1980034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 1990034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 2000034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi window_counts[i] = 2010034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 2020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi pad_low[i] = padding_both[i].first; 2030034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2040034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi auto result = MakeUnique<std::vector<float>>(window_counts[0]); 2050034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2060034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi // Do a full 1D reduce window. 2070034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2080034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 2090034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2100034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi float val = init; 2110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { 2130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi val = reduce_func(val, operand[i0_base + i0_win]); 2140034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2150034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2160034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi (*result)[i0] = val; 2170034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 2180034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return result; 2190034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 2200034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2210034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 2220034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd( 2230034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 2240034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2250034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2260034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 2270034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, 2280034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi padding); 2290034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 2300034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 2316bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( 2326bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const Array2D<float>& operand, float init, 2336bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2346bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2356bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> dim_lengths{operand.height(), operand.width()}; 2366bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 2376bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2386bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 2396bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 2406bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 2416bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi window_counts[i] = 2426bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 2436bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi pad_low[i] = padding_both[i].first; 2446bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2456bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]); 2466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi // Do a full 2D reduce window. 2486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 2506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 2516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i1_base = i1 * stride[1] - pad_low[1]; 2526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi float val = init; 2546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i0_base + i0_win < operand.n1() && 2586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i1_base + i1_win < operand.n2()) { 2596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi val += operand(i0_base + i0_win, i1_base + i1_win); 2606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi (*result)(i0, i1) = val; 2646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi return result; 2676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi} 2686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2692d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> 2702d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric( 2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, float init, 2722d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const std::function<float(float, float)>& reduce_func, 2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 277f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune return ReduceWindow4DGeneric( 278f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand, init, reduce_func, window, stride, 279f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune xla::MakePadding(dim_lengths, window, stride, padding)); 280f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune} 281f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune 282f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>> 283f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric( 284f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const Array4D<float>& operand, float init, 285f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const std::function<float(float, float)>& reduce_func, 286f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& window, 287f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& stride, 288f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 289f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 290f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand.n4()}; 2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 295f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 297f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune window_util::StridedBound(padded_width, window[i], stride[i]); 298f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune pad_low[i] = padding[i].first; 2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], 3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[2], window_counts[3]); 3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D reduce window. 3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = init; 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 3232d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val = reduce_func( 3242d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val, operand(i0_base + i0_win, i1_base + i1_win, 3252d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi i2_base + i2_win, i3_base + i3_win)); 3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i0, i1, i2, i3) = val; 3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3392d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( 3402d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const Array4D<float>& operand, float init, 3412d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 3422d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 3432d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 3442d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, 3452d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi padding); 3462d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi} 3472d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi 3481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D( 3491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& input, const Array4D<float>& mean, 3501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& var, const Array4D<float>& scale, 3511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& offset, float epsilon) { 3521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto normalized = 3531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(input, mean, [](float a, float b) { return a - b; }); 3541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = *MapArray4D(normalized, var, [&](float a, float b) { 3551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return a / std::sqrt(b + epsilon); 3561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower }); 3571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = 3581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); 3591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); 3601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 3611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus( 3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, const Array4D<float>& source, float init, 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) { 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding = same_padding ? Padding::kSame : Padding::kValid; 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(), 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n3(), operand.n4()); 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Fill the output, with the initial value. 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WindowCount(dim_lengths[i], window[i], stride[i], padding); 3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins pad_low[i] = padding_both[i].first; 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[0], source.n1()); 3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[1], source.n2()); 3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[2], source.n3()); 3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[3], source.n4()); 3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D select and Scatter. 3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Now we are inside a window and need to find the max and the argmax. 3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; 3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float tmp = operand(i0_base + i0_win, i1_base + i1_win, 4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win, i3_base + i3_win); 4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (tmp >= val) { 4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins val = tmp; 4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_0 = i0_base + i0_win; 4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_1 = i1_base + i1_win; 4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_2 = i2_base + i2_win; 4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_3 = i3_base + i3_win; 4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins source(i0, i1, i2, i3); 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions( 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dimension_numbers) { 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, 442f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower {1, 1}, {1, 1}, 443f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower std::move(dimension_numbers)); 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated( 4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, 4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dnums) { 452253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); 453253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs); 454253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs); 455253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 456253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_kernel_strides; 457253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_input_dimensions; 458253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::array<int64, 2> ordered_kernel_dimensions; 459253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) { 460253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[0] = kernel_stride.second; 461253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[1] = kernel_stride.first; 462253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu } else { 463253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[0] = kernel_stride.first; 464253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides[1] = kernel_stride.second; 4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 467253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_input_dimensions[0] = 468253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu lhs_literal->shape().dimensions(dnums.spatial_dimensions(0)); 469253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_input_dimensions[1] = 470253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu lhs_literal->shape().dimensions(dnums.spatial_dimensions(1)); 471253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_dimensions[0] = 472253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); 473253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_dimensions[1] = 474253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); 475253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 476253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::vector<std::pair<int64, int64>> paddings = 477253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, 478253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ordered_kernel_strides, padding); 479253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu CHECK_EQ(paddings.size(), 2); 480253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 481253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu Window window; 482253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 483253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu WindowDimension dim; 484253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_size( 485253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); 486253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_stride(kernel_stride.first); 487253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_padding_low(paddings[0].first); 488253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_padding_high(paddings[0].second); 489253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_window_dilation(rhs_dilation.first); 490253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim.set_base_dilation(lhs_dilation.first); 491253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *window.add_dimensions() = dim; 492253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 493253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu WindowDimension dim2; 494253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_size( 495253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); 496253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_stride(kernel_stride.second); 497253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_padding_low(paddings[1].first); 498253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_padding_high(paddings[1].second); 499253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_window_dilation(rhs_dilation.second); 500253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu dim2.set_base_dilation(lhs_dilation.second); 501253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *window.add_dimensions() = dim2; 502253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 503253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu const Shape& shape = 504253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu ShapeInference::InferConvolveShape(lhs_literal->shape(), 505253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu rhs_literal->shape(), window, dnums) 506253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu .ConsumeValueOrDie(); 507253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 508253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloInstruction* lhs_instruction = 509253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); 510253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloInstruction* rhs_instruction = 511253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); 512253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 513253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu b.AddInstruction(HloInstruction::CreateConvolve( 514253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu shape, lhs_instruction, rhs_instruction, window, dnums)); 515253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 516253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu HloEvaluator evaluator; 517253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu std::unique_ptr<Literal> result_literal = 518253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu evaluator.Evaluate(*b.Build(), {}).ConsumeValueOrDie(); 519253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 520253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); 521b01346de8b5893a09d50ff4d9c80ca442a327a76Kay Zhu auto result = 522253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0), 523253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(1), 524253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(2), 525253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result_literal->shape().dimensions(3)); 526253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu 527253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 528253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu *value = result_literal->Get<float>(indices); 529253bcbb71bdd1f9f2609b085dce90fe9b31cbd5aKay Zhu }); 5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D( 5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 5377e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(i, j)); 5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D( 5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 5547e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < cols; ++i) { 5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < rows; ++j) { 5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(j, i)); 5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D( 5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& array, float init, 5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 5717e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> result; 5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 3); 5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::set<int64> dim_set(dims.begin(), dims.end()); 5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dim_set.size(), 3); 5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { 5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); 5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a1) { 5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); 5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a2) { 5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); 5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a3) { 5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float accumulator = init; 5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); 5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i0) { 5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); 5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i1) { 5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; 5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { 5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; 5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { 5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator = reduce_function( 5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); 5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result.push_back(accumulator); 5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D( 6071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const std::vector<float>& array, const std::vector<int64>& bounds, 6081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 broadcast_from_dim) { 6091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto result = 6101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]); 6111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < result->n1(); ++i) { 6121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 j = 0; j < result->n2(); ++j) { 6131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 k = 0; k < result->n3(); ++k) { 6141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 l = 0; l < result->n4(); ++l) { 6151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower switch (broadcast_from_dim) { 6161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 0: 6171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[i]; 6181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 1: 6201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[j]; 6211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 2: 6231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[k]; 6241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 3: 6261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[l]; 6271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower default: 6291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return result; 6361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 6371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D( 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array3D<float>& array, float init, 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 6417e08b5c7ae5a10810948095b96d2da53e816e446Eli Bendersky const std::function<float(float, float)>& reduce_function) { 6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 1); 6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = dims[0] == 0 ? array.n2() : array.n1(); 6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = dims[0] == 2 ? array.n2() : array.n3(); 6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i0 = 0; i0 < array.n1(); ++i0) { 6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i1 = 0; i1 < array.n2(); ++i1) { 6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i2 = 0; i2 < array.n3(); ++i2) { 6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 row = dims[0] == 0 ? i1 : i0; 6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 col = dims[0] == 2 ? i1 : i2; 6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(row, col) = 6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins reduce_function((*result)(row, col), array(i0, i1, i2)); 6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float)>& map_function) { 6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j)); 6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 6751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs, 6761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, float)>& map_function) { 6771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.height(), rhs.height()); 6781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.width()); 6791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = lhs.height(); 6801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = rhs.width(); 6811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); 6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D( 6911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, int64, int64)>& map_function) { 6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 6961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 6971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 6981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j), i, j); 6991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D( 7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand, const PaddingConfig& padding, 7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float pad) { 7071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in0 = operand.n1(); 7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding0 = padding.dimensions(0).edge_padding_high(); 7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding0 = padding.dimensions(0).edge_padding_low(); 7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding0 = padding.dimensions(0).interior_padding(); 7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out0 = 7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; 7135aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in1 = operand.n2(); 7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding1 = padding.dimensions(1).edge_padding_high(); 7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding1 = padding.dimensions(1).edge_padding_low(); 7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding1 = padding.dimensions(1).interior_padding(); 7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out1 = 7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; 7205aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(out0, out1); 7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(pad); 7235aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o0 = low_padding0; 7245aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i0 = 0; i0 < in0; ++i0) { 7255aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o1 = low_padding1; 7265aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i1 = 0; i1 < in1; ++i1) { 7275aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { 7285aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan (*result)(o0, o1) = operand(i0, i1); 7295aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan } 7305aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o1 += interior_padding1 + 1; 7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7325aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o0 += interior_padding0 + 1; 7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7379b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower/* static */ Array3D<float> ReferenceUtil::PadArray3D( 7389b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const Array3D<float>& operand, const PaddingConfig& padding, 7399b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const float pad) { 7409b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_EQ(padding.dimensions_size(), 3); 7419b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 7429b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), 7439b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower operand.n3()}; 7449b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::vector<int64> pad_low(3); 7459b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::vector<int64> pad_high(3); 7469b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::vector<int64> pad_interior(3); 7479b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::vector<int64> output_bounds(3); 7489b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower for (int64 i = 0; i < 3; ++i) { 7499b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower pad_low[i] = padding.dimensions(i).edge_padding_low(); 7509b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower pad_high[i] = padding.dimensions(i).edge_padding_high(); 7519b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_LE(0, pad_low[i]); 7529b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_LE(0, pad_high[i]); 7539b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; 7549b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower pad_interior[i] = padding.dimensions(i).interior_padding(); 7559b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 7569b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + 7579b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower (input_bounds[i] - 1) * pad_interior[i]; 7589b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7599b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 7609b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower Array3D<float> result(output_bounds[0], output_bounds[1], output_bounds[2]); 7619b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower std::vector<int> indices = {0, 0, 0}; 7629b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { 7639b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { 7649b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { 7659b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower float* value = &result(indices[0], indices[1], indices[2]); 7669b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower bool value_padded = false; 7679b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower for (int i = 0; i < 3; ++i) { 7689b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower bool in_low_padding = indices[i] < pad_low[i]; 7699b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; 7709b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower if (in_low_padding || in_high_padding) { 7719b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value = pad; 7729b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower value_padded = true; 7739b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7749b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower if (pad_interior[i] && 7759b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { 7769b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value = pad; 7779b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower value_padded = true; 7789b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7799b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7809b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower if (value_padded) { 7819b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower continue; 7829b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7839b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), 7849b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower (indices[1] - pad_low[1]) / (pad_interior[1] + 1), 7859b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); 7869b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7879b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7889b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower } 7899b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower return result; 7909b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower} 7919b28e435c1507be52129ec22d9d006d12b1e79f3A. Unique TensorFlower 792c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune/* static */ Array4D<float> ReferenceUtil::PadArray4D( 793c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const Array4D<float>& operand, const PaddingConfig& padding, 794c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const float pad) { 795c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune CHECK_EQ(padding.dimensions_size(), 4); 796c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 797c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), 798c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune operand.n3(), operand.n4()}; 799c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> pad_low(4); 800c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> pad_high(4); 8018ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower std::vector<int64> pad_interior(4); 802c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> output_bounds(4); 803c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune for (int64 i = 0; i < 4; ++i) { 804c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune pad_low[i] = padding.dimensions(i).edge_padding_low(); 805c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune pad_high[i] = padding.dimensions(i).edge_padding_high(); 8068ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; 8078ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower pad_interior[i] = padding.dimensions(i).interior_padding(); 808c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 8098ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + 8108ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (input_bounds[i] - 1) * pad_interior[i]; 811c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 812c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 813c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2], 814c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune output_bounds[3]); 815c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 816c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune for (int i = 0; i < 4; ++i) { 817c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune bool in_low_padding = indices[i] < pad_low[i]; 818c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; 819c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune if (in_low_padding || in_high_padding) { 820c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune *value = pad; 821c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune return; 822c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 8238ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower if (pad_interior[i] && 8248ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { 8258ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower *value = pad; 8268ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower return; 8278ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower } 828c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 8298ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), 8308ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[1] - pad_low[1]) / (pad_interior[1] + 1), 8318ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[2] - pad_low[2]) / (pad_interior[2] + 1), 8328ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); 833c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune }); 834c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune return result; 835c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune} 836c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 8371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 838