reference_util.cc revision 6bbbd7e9d2016dfd201797d1f1354ccc48bd9e13
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D( 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand) { 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height()); 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 w = 0; w < operand.width(); ++w) { 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 h = 0; h < operand.height(); ++h) { 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(w, h) = operand(h, w); 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D( 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs) { 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(m, n); 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF32( 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D( 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<double>& lhs, const Array2D<double>& rhs) { 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(m, n); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF64( 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64( 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& input) { 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(input.height(), input.width()); 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 rowno = 0; rowno < input.height(); ++rowno) { 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 colno = 0; colno < input.height(); ++colno) { 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(rowno, colno) = input(rowno, colno); 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D( 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding) { 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensions( 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs, rhs, kernel_stride, padding, 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder::CreateDefaultConvDimensionNumbers()); 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 967fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>> 977fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, 987fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& depthwise_weights, 997fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& pointwise_weights, 1007fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee std::pair<int64, int64> kernel_stride, 1017fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Padding padding) { 1027fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const int64 depth_multiplier = depthwise_weights.planes(); 1037fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier); 1047fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1057fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // Combine the two weights by reducing the depth_multiplier, so that we can 1067fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // apply a single convolution on the combined weights. 1077fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Array4D<float> weights(pointwise_weights.planes(), input.depth(), 1087fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights.height(), depthwise_weights.width()); 1097fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) { 1107fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) { 1117fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kz = 0; kz < input.depth(); ++kz) { 1127fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 out = 0; out < pointwise_weights.planes(); ++out) { 1137fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee float weight = 0.0; 1147fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 depth = 0; depth < depth_multiplier; ++depth) { 1157fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weight += 1167fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights(depth, kz, ky, kx) * 1177fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee pointwise_weights(out, depth + kz * depth_multiplier, 0, 0); 1187fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1197fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weights(out, kz, ky, kx) = weight; 1207fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1217fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1227fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1237fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1247fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1257fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee return ConvArray4D(input, weights, kernel_stride, padding); 1267fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee} 1277fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 window_len, int64 stride, 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding) { 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kValid) { 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return window_util::StridedBound(unpadded_width, window_len, stride); 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1376bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( 1386bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const Array2D<float>& operand, float init, 1396bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1406bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1416bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> dim_lengths{operand.height(), operand.width()}; 1426bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 1436bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1446bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 1456bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 1466bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 1476bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi window_counts[i] = 1486bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 1496bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi pad_low[i] = padding_both[i].first; 1506bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1516bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]); 1526bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1536bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi // Do a full 2D reduce window. 1546bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 1556bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 1566bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 1576bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i1_base = i1 * stride[1] - pad_low[1]; 1586bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1596bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi float val = init; 1606bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 1616bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 1626bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 1636bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i0_base + i0_win < operand.n1() && 1646bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i1_base + i1_win < operand.n2()) { 1656bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi val += operand(i0_base + i0_win, i1_base + i1_win); 1666bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1676bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1686bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1696bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi (*result)(i0, i1) = val; 1706bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1716bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1726bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi return result; 1736bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi} 1746bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, float init, 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WindowCount(dim_lengths[i], window[i], stride[i], padding); 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins pad_low[i] = padding_both[i].first; 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[2], window_counts[3]); 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D reduce window. 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = init; 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins val += operand(i0_base + i0_win, i1_base + i1_win, 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win, i3_base + i3_win); 2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i0, i1, i2, i3) = val; 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus( 2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, const Array4D<float>& source, float init, 2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) { 2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding = same_padding ? Padding::kSame : Padding::kValid; 2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(), 2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n3(), operand.n4()); 2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Fill the output, with the initial value. 2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WindowCount(dim_lengths[i], window[i], stride[i], padding); 2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins pad_low[i] = padding_both[i].first; 2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[0], source.n1()); 2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[1], source.n2()); 2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[2], source.n3()); 2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[3], source.n4()); 2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D select and Scatter. 2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Now we are inside a window and need to find the max and the argmax. 2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; 2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; 2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; 2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; 2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); 2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float tmp = operand(i0_base + i0_win, i1_base + i1_win, 2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win, i3_base + i3_win); 2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (tmp >= val) { 2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins val = tmp; 2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_0 = i0_base + i0_win; 2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_1 = i1_base + i1_win; 2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_2 = i2_base + i2_win; 2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_3 = i3_base + i3_win; 2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += 2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins source(i0, i1, i2, i3); 2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions( 3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dimension_numbers) { 3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, 3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {1, 1}, {1, 1}, dimension_numbers); 3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated( 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dnums) { 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}}; 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}}; 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ksy = kernel_stride.first; 3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ksx = kernel_stride.second; 3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dy = lhs_dilation.first; 3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dx = lhs_dilation.second; 3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dky = rhs_dilation.first; 3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dkx = rhs_dilation.second; 3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dky, 1); 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dkx, 1); 3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dy, 1); 3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dx, 1); 3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Get all dimension sizes in lhs and rhs based on the given convolution 3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // dimension configuration. 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ix = window_util::DilatedBound( 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs_dimensions[dnums.spatial_dimensions(1)], dx); 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 iy = window_util::DilatedBound( 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs_dimensions[dnums.spatial_dimensions(0)], dy); 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 iz = lhs_dimensions[dnums.feature_dimension()]; 3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 samples = lhs_dimensions[dnums.batch_dimension()]; 3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 kx = window_util::DilatedBound( 3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx); 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ky = window_util::DilatedBound( 3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky); 3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()]; 3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins { 3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()]; 3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kiz, iz); 3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kSame) { 3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // We reject same padding with kernel striding, since it's somewhat 3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // nonsensical. We can always follow up to implement this with the desired 3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // semantics if anybody actually uses it. 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(1, ksy); 3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(1, ksx); 3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ox = 3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx); 3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 oy = 3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy); 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 istartx = 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2; 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 istarty = 3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2; 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Create the output result array and reset the values to 0. 3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> result_dimensions; 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.batch_dimension()] = samples; 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.feature_dimension()] = oz; 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.spatial_dimensions(0)] = oy; 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.spatial_dimensions(1)] = ox; 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = 3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1], 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[2], result_dimensions[3]); 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(0.0); 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3767135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto is_int32 = [](int64 x) { 3777135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return x >= std::numeric_limits<int32>::min() && 3787135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar x <= std::numeric_limits<int32>::max(); 3797135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 3807135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar 3817135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at 3827135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar // least on x86-64), so we avoid them where possible. 3837135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto fast_idiv64 = [&](int64 a, int64 b) { 3847135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (is_int32(a) && is_int32(b)) { 3857135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b)); 3867135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar } 3877135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return a / b; 3887135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 3897135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto fast_imod64 = [&](int64 a, int64 b) { 3907135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (is_int32(a) && is_int32(b)) { 3917135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b)); 3927135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar } 3937135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return a % b; 3947135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 3957135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar 3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Lambda to access the lhs operand at the given 4D index. 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const auto lhs_element = [&](int64 batch, int64 feature, int64 height, 3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 width) { 3997135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) { 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return 0.0f; 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> index; 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.batch_dimension()] = batch; 4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.feature_dimension()] = feature; 4067135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy); 4077135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx); 4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return lhs(index[0], index[1], index[2], index[3]); 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }; 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 411f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // Lambda to access the rhs operand at the given 4D index. height_over_dky 412f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // should be equal to height / dky, and width_over_dkx should be equal to 413f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // width / dkx. (This is an optimization to avoid doing divisions.) 414f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar const auto rhs_element = [&]( 415f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar int64 kernel_output_feature, int64 kernel_input_feature, int64 height, 416f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar int64 width, int64 height_over_dky, int64 width_over_dkx) { 417f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar DCHECK_EQ(height % dky, 0); 418f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar DCHECK_EQ(width % dkx, 0); 419f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar DCHECK_EQ(height / dky, height_over_dky); 420f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar DCHECK_EQ(width / dkx, width_over_dkx); 421f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> index; 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; 425f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; 426f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return rhs(index[0], index[1], index[2], index[3]); 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }; 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Lambda to access the result data at the given 4D index. 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const auto result_element = [&](int64 batch, int64 kernel_output_feature, 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 height, int64 width) -> float& { 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> index; 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.batch_dimension()] = batch; 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.feature_dimension()] = kernel_output_feature; 4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.spatial_dimensions(0)] = height; 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.spatial_dimensions(1)] = width; 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return (*result)(index[0], index[1], index[2], index[3]); 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }; 4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 oyi = 0; oyi < oy; ++oyi) { 4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 oxi = 0; oxi < ox; ++oxi) { 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 sample = 0; sample < samples; ++sample) { 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 izi = 0; izi < iz; ++izi) { 4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 ozi = 0; ozi < oz; ++ozi) { 446f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky; 447f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar kyi += dky, kyi_over_dky++) { 448f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx; 449f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar kxi += dkx, kxi_over_dkx++) { 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 iyi = istarty + ksy * oyi + kyi; 4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 ixi = istartx + ksx * oxi + kxi; 4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) 4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ? 0.0 4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins : lhs_element(sample, izi, iyi, ixi); 455f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar float gain = 456f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx); 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float addend = input * gain; 4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_element(sample, ozi, oyi, oxi) += addend; 4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D( 4711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 4771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(i, j)); 4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D( 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < cols; ++i) { 4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < rows; ++j) { 4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(j, i)); 4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D( 5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& array, float init, 5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> result; 5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 3); 5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::set<int64> dim_set(dims.begin(), dims.end()); 5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dim_set.size(), 3); 5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { 5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); 5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a1) { 5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); 5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a2) { 5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); 5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a3) { 5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float accumulator = init; 5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); 5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i0) { 5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); 5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i1) { 5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; 5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { 5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; 5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { 5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator = reduce_function( 5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); 5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result.push_back(accumulator); 5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D( 5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array3D<float>& array, float init, 5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 1); 5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = dims[0] == 0 ? array.n2() : array.n1(); 5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = dims[0] == 2 ? array.n2() : array.n3(); 5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i0 = 0; i0 < array.n1(); ++i0) { 5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i1 = 0; i1 < array.n2(); ++i1) { 5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i2 = 0; i2 < array.n3(); ++i2) { 5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 row = dims[0] == 0 ? i1 : i0; 5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 col = dims[0] == 2 ? i1 : i2; 5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(row, col) = 5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins reduce_function((*result)(row, col), array(i0, i1, i2)); 5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float)>& map_function) { 5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j)); 5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs, 5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, float)>& map_function) { 5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.height(), rhs.height()); 5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.width()); 5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = lhs.height(); 5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = rhs.width(); 5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); 5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D( 5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, int64, int64)>& map_function) { 5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j), i, j); 6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D( 6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand, const PaddingConfig& padding, 6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float pad) { 6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in0 = operand.n1(); 6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding0 = padding.dimensions(0).edge_padding_high(); 6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding0 = padding.dimensions(0).edge_padding_low(); 6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding0 = padding.dimensions(0).interior_padding(); 6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out0 = 6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; 6165aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in1 = operand.n2(); 6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding1 = padding.dimensions(1).edge_padding_high(); 6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding1 = padding.dimensions(1).edge_padding_low(); 6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding1 = padding.dimensions(1).interior_padding(); 6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out1 = 6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; 6235aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(out0, out1); 6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(pad); 6265aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o0 = low_padding0; 6275aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i0 = 0; i0 < in0; ++i0) { 6285aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o1 = low_padding1; 6295aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i1 = 0; i1 < in1; ++i1) { 6305aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { 6315aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan (*result)(o0, o1) = operand(i0, i1); 6325aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan } 6335aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o1 += interior_padding1 + 1; 6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6355aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o0 += interior_padding0 + 1; 6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 641