reference_util.cc revision 0034029ac66bb60f272fa1aae05eca5dd9d210d1
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/reference_util.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <array> 19f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower#include <utility> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/window_util.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/math/math_util.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D( 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand) { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height()); 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 w = 0; w < operand.width(); ++w) { 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 h = 0; h < operand.height(); ++h) { 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(w, h) = operand(h, w); 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D( 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs) { 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(m, n); 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF32( 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D( 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<double>& lhs, const Array2D<double>& rhs) { 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.height()); 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int m = lhs.height(); 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int n = rhs.width(); 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int k = lhs.width(); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(m, n); 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Because Eigen is a header-oriented library, make sure that the Eigen code 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // is the same as the code used by the CPU backend (otherwise the linker will 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // randomly pick *some* definition). 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins __xla_cpu_runtime_EigenSingleThreadedMatMulF64( 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins k, 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_lhs=*/0, 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*transpose_rhs=*/0); 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64( 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& input) { 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<double>>(input.height(), input.width()); 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 rowno = 0; rowno < input.height(); ++rowno) { 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 colno = 0; colno < input.height(); ++colno) { 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(rowno, colno) = input(rowno, colno); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D( 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding) { 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensions( 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs, rhs, kernel_stride, padding, 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder::CreateDefaultConvDimensionNumbers()); 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 977fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee/* static */ std::unique_ptr<Array4D<float>> 987fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong LeeReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, 997fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& depthwise_weights, 1007fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const Array4D<float>& pointwise_weights, 1017fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee std::pair<int64, int64> kernel_stride, 1027fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Padding padding) { 1037fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee const int64 depth_multiplier = depthwise_weights.planes(); 1047fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier); 1057fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1067fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // Combine the two weights by reducing the depth_multiplier, so that we can 1077fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee // apply a single convolution on the combined weights. 1087fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee Array4D<float> weights(pointwise_weights.planes(), input.depth(), 1097fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights.height(), depthwise_weights.width()); 1107fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) { 1117fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) { 1127fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 kz = 0; kz < input.depth(); ++kz) { 1137fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 out = 0; out < pointwise_weights.planes(); ++out) { 1147fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee float weight = 0.0; 1157fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee for (int64 depth = 0; depth < depth_multiplier; ++depth) { 1167fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weight += 1177fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee depthwise_weights(depth, kz, ky, kx) * 1187fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee pointwise_weights(out, depth + kz * depth_multiplier, 0, 0); 1197fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1207fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee weights(out, kz, ky, kx) = weight; 1217fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1227fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1237fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1247fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee } 1257fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1267fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee return ConvArray4D(input, weights, kernel_stride, padding); 1277fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee} 1287fce98639e574b159a492c0b3ba6e7343272fa06HyoukJoong Lee 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 window_len, int64 stride, 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding) { 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kValid) { 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return window_util::StridedBound(unpadded_width, window_len, stride); 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1380034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 1390034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DGeneric( 1400034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 1410034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const std::function<float(float, float)>& reduce_func, 1420034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1430034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1440034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 1450034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 1460034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1470034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 1480034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 1490034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 1500034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi window_counts[i] = 1510034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 1520034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi pad_low[i] = padding_both[i].first; 1530034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 1540034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi auto result = MakeUnique<std::vector<float>>(window_counts[0]); 1550034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1560034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi // Do a full 1D reduce window. 1570034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 1580034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 1590034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1600034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi float val = init; 1610034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 1620034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { 1630034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi val = reduce_func(val, operand[i0_base + i0_win]); 1640034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 1650034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 1660034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi (*result)[i0] = val; 1670034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi } 1680034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return result; 1690034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 1700034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1710034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi/* static */ std::unique_ptr<std::vector<float>> 1720034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo OguntebiReferenceUtil::ReduceWindow1DAdd( 1730034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<float>& operand, float init, 1740034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1750034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1760034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 1770034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, 1780034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi padding); 1790034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi} 1800034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 1816bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( 1826bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const Array2D<float>& operand, float init, 1836bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 1846bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 1856bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> dim_lengths{operand.height(), operand.width()}; 1866bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 1876bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1886bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> window_counts(window.size(), 0); 1896bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi std::vector<int64> pad_low(window.size(), 0); 1906bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i = 0; i < window.size(); ++i) { 1916bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi window_counts[i] = 1926bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi WindowCount(dim_lengths[i], window[i], stride[i], padding); 1936bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi pad_low[i] = padding_both[i].first; 1946bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 1956bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]); 1966bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 1976bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi // Do a full 2D reduce window. 1986bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 1996bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 2006bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i0_base = i0 * stride[0] - pad_low[0]; 2016bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi int64 i1_base = i1 * stride[1] - pad_low[1]; 2026bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2036bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi float val = init; 2046bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2056bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2066bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2076bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i0_base + i0_win < operand.n1() && 2086bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi i1_base + i1_win < operand.n2()) { 2096bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi val += operand(i0_base + i0_win, i1_base + i1_win); 2106bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2116bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2126bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2136bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi (*result)(i0, i1) = val; 2146bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2156bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi } 2166bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi return result; 2176bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi} 2186bbbd7e9d2016dfd201797d1f1354ccc48bd9e13Tayo Oguntebi 2192d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> 2202d69270342d2a5e46446e02e9273e7da79f00accTayo OguntebiReferenceUtil::ReduceWindow4DGeneric( 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, float init, 2222d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const std::function<float(float, float)>& reduce_func, 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 227f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune return ReduceWindow4DGeneric( 228f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand, init, reduce_func, window, stride, 229f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune xla::MakePadding(dim_lengths, window, stride, padding)); 230f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune} 231f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune 232f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune/* static */ std::unique_ptr<Array4D<float>> 233f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt RouneReferenceUtil::ReduceWindow4DGeneric( 234f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const Array4D<float>& operand, float init, 235f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const std::function<float(float, float)>& reduce_func, 236f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& window, 237f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<int64>& stride, 238f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 239f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 240f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune operand.n4()}; 2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 245f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 247f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune window_util::StridedBound(padded_width, window[i], stride[i]); 248f8051d04d982b5877717f1e4145488263aa69c68Bjarke Hammersholt Roune pad_low[i] = padding[i].first; 2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], 2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[2], window_counts[3]); 2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D reduce window. 2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = init; 2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 2732d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val = reduce_func( 2742d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi val, operand(i0_base + i0_win, i1_base + i1_win, 2752d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi i2_base + i2_win, i3_base + i3_win)); 2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i0, i1, i2, i3) = val; 2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2892d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( 2902d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const Array4D<float>& operand, float init, 2912d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& window, 2922d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 2932d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 2942d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, 2952d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi padding); 2962d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi} 2972d69270342d2a5e46446e02e9273e7da79f00accTayo Oguntebi 2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D( 2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& input, const Array4D<float>& mean, 3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& var, const Array4D<float>& scale, 3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const Array4D<float>& offset, float epsilon) { 3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto normalized = 3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(input, mean, [](float a, float b) { return a - b; }); 3041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = *MapArray4D(normalized, var, [&](float a, float b) { 3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return a / std::sqrt(b + epsilon); 3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower }); 3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower normalized = 3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); 3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); 3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::SelectAndScatter4DGePlus( 3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& operand, const Array4D<float>& source, float init, 3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& window, 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) { 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding = same_padding ? Padding::kSame : Padding::kValid; 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(), 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n3(), operand.n4()); 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand.n4()}; 3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Fill the output, with the initial value. 3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> window_counts(window.size(), 0); 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> pad_low(window.size(), 0); 3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < window.size(); ++i) { 3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins window_counts[i] = 3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WindowCount(dim_lengths[i], window[i], stride[i], padding); 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins pad_low[i] = padding_both[i].first; 3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[0], source.n1()); 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[1], source.n2()); 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[2], source.n3()); 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(window_counts[3], source.n4()); 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Do a full 4D select and Scatter. 3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Now we are inside a window and need to find the max and the argmax. 3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i0_base = i0 * stride[0] - pad_low[0]; 3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i1_base = i1 * stride[1] - pad_low[1]; 3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i2_base = i2 * stride[2] - pad_low[2]; 3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 i3_base = i3 * stride[3] - pad_low[3]; 3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; 3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; 3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; 3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; 3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i0_base + i0_win < operand.n1() && 3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i1_base + i1_win < operand.n2() && 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win < operand.n3() && 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3_base + i3_win < operand.n4()) { 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float tmp = operand(i0_base + i0_win, i1_base + i1_win, 3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2_base + i2_win, i3_base + i3_win); 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (tmp >= val) { 3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins val = tmp; 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_0 = i0_base + i0_win; 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_1 = i1_base + i1_win; 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_2 = i2_base + i2_win; 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter_3 = i3_base + i3_win; 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += 3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins source(i0, i1, i2, i3); 3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensions( 3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dimension_numbers) { 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, 392f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower {1, 1}, {1, 1}, 393f4b8d21b8e41636b6e61f0a1de753430108d2ee7A. Unique TensorFlower std::move(dimension_numbers)); 3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array4D<float>> 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ConvArray4DGeneralDimensionsDilated( 3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& lhs, const Array4D<float>& rhs, 3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> kernel_stride, Padding padding, 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ConvolutionDimensionNumbers dnums) { 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}}; 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}}; 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ksy = kernel_stride.first; 4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ksx = kernel_stride.second; 4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dy = lhs_dilation.first; 4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dx = lhs_dilation.second; 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dky = rhs_dilation.first; 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 dkx = rhs_dilation.second; 4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dky, 1); 4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dkx, 1); 4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dy, 1); 4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_GE(dx, 1); 4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Get all dimension sizes in lhs and rhs based on the given convolution 4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // dimension configuration. 4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ix = window_util::DilatedBound( 4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs_dimensions[dnums.spatial_dimensions(1)], dx); 4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 iy = window_util::DilatedBound( 4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lhs_dimensions[dnums.spatial_dimensions(0)], dy); 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 iz = lhs_dimensions[dnums.feature_dimension()]; 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 samples = lhs_dimensions[dnums.batch_dimension()]; 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 kx = window_util::DilatedBound( 4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx); 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ky = window_util::DilatedBound( 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky); 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()]; 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins { 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()]; 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kiz, iz); 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == Padding::kSame) { 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // We reject same padding with kernel striding, since it's somewhat 4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // nonsensical. We can always follow up to implement this with the desired 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // semantics if anybody actually uses it. 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(1, ksy); 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(1, ksx); 4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 ox = 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx); 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 oy = 4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy); 4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 istartx = 4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2; 4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 istarty = 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2; 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Create the output result array and reset the values to 0. 4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> result_dimensions; 4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.batch_dimension()] = samples; 4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.feature_dimension()] = oz; 4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.spatial_dimensions(0)] = oy; 4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[dnums.spatial_dimensions(1)] = ox; 4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1], 4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_dimensions[2], result_dimensions[3]); 4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(0.0); 4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4617135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto is_int32 = [](int64 x) { 4627135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return x >= std::numeric_limits<int32>::min() && 4637135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar x <= std::numeric_limits<int32>::max(); 4647135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 4657135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar 4667135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at 4677135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar // least on x86-64), so we avoid them where possible. 4687135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto fast_idiv64 = [&](int64 a, int64 b) { 4697135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (is_int32(a) && is_int32(b)) { 4707135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b)); 4717135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar } 4727135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return a / b; 4737135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 4747135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar const auto fast_imod64 = [&](int64 a, int64 b) { 4757135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (is_int32(a) && is_int32(b)) { 4767135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b)); 4777135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar } 4787135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar return a % b; 4797135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar }; 4807135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar 4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Lambda to access the lhs operand at the given 4D index. 4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const auto lhs_element = [&](int64 batch, int64 feature, int64 height, 4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 width) { 4847135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) { 4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return 0.0f; 4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> index; 4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.batch_dimension()] = batch; 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.feature_dimension()] = feature; 4917135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy); 4927135d08d4e6067865d7b5f2907013c960a12ae4fJustin Lebar index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx); 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return lhs(index[0], index[1], index[2], index[3]); 4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }; 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 496f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // Lambda to access the rhs operand at the given 4D index. height_over_dky 497f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // should be equal to height / dky, and width_over_dkx should be equal to 498f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar // width / dkx. (This is an optimization to avoid doing divisions.) 4990034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi const auto rhs_element = 5000034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height, 5010034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi int64 width, int64 height_over_dky, int64 width_over_dkx) { 5020034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi DCHECK_EQ(height % dky, 0); 5030034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi DCHECK_EQ(width % dkx, 0); 5040034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi DCHECK_EQ(height / dky, height_over_dky); 5050034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi DCHECK_EQ(width / dkx, width_over_dkx); 5060034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi 5070034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi std::array<int64, 4> index; 5080034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; 5090034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; 5100034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; 5110034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; 5120034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi return rhs(index[0], index[1], index[2], index[3]); 5130034029ac66bb60f272fa1aae05eca5dd9d210d1Tayo Oguntebi }; 5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Lambda to access the result data at the given 4D index. 5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const auto result_element = [&](int64 batch, int64 kernel_output_feature, 5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 height, int64 width) -> float& { 5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::array<int64, 4> index; 5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.batch_dimension()] = batch; 5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.feature_dimension()] = kernel_output_feature; 5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.spatial_dimensions(0)] = height; 5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins index[dnums.spatial_dimensions(1)] = width; 5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return (*result)(index[0], index[1], index[2], index[3]); 5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }; 5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 oyi = 0; oyi < oy; ++oyi) { 5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 oxi = 0; oxi < ox; ++oxi) { 5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 sample = 0; sample < samples; ++sample) { 5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 izi = 0; izi < iz; ++izi) { 5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 ozi = 0; ozi < oz; ++ozi) { 531f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky; 532f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar kyi += dky, kyi_over_dky++) { 533f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx; 534f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar kxi += dkx, kxi_over_dkx++) { 5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 iyi = istarty + ksy * oyi + kyi; 5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 ixi = istartx + ksx * oxi + kxi; 5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) 5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ? 0.0 5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins : lhs_element(sample, izi, iyi, ixi); 540f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar float gain = 541f89185b655d380074c3e1e932e90d80cd2b01241Justin Lebar rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx); 5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float addend = input * gain; 5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result_element(sample, ozi, oyi, oxi) += addend; 5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5511c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 || 5521c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower iz == 0) { 5531c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower LOG(INFO) << "Output will be trivially empty because one of these " 5541c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower "dimensions is 0: samples: " 5551c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox 5561c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower << " oy: " << oy << " oz: " << oz << " iz: " << iz; 5571c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower return result; 5581c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower } 5591c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower bool trivial = true; 5601c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices, 5611c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower float value) { 5621c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower if (value != 0.0) { 5631c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower trivial = false; 5641c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower } 5651c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower }; 5661c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower lhs.Each(check_trivial); 567beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins if (trivial) { 568beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins LOG(FATAL) << "LHS is all 0.0."; 569beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins } 5701c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower trivial = true; 5711c57e8864aeb8093f12e8f390571fb8cd9f376a4A. Unique TensorFlower rhs.Each(check_trivial); 572beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins if (trivial) { 573beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins LOG(FATAL) << "RHS is all 0.0."; 574beb291b3c0e95f1150c446e06f8f80675ab3b528Peter Hawkins } 5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToColArray2D( 5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 5831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 5861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 5871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(i, j)); 5891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<std::vector<float>> 5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsReferenceUtil::ReduceToRowArray2D( 5971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, float init, 5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<std::vector<float>>(); 6021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < cols; ++i) { 6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float acc = init; 6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < rows; ++j) { 6051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins acc = reduce_function(acc, matrix(j, i)); 6061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->push_back(acc); 6081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D( 6131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array4D<float>& array, float init, 6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 6161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> result; 6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 3); 6181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::set<int64> dim_set(dims.begin(), dims.end()); 6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dim_set.size(), 3); 6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { 6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); 6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a1) { 6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); 6241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a2) { 6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); 6261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++a3) { 6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float accumulator = init; 6281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); 6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i0) { 6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); 6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ++i1) { 6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i2 = 0; 6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { 6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i3 = 0; 6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { 6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator = reduce_function( 6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result.push_back(accumulator); 6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D( 6511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const std::vector<float>& array, const std::vector<int64>& bounds, 6521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 broadcast_from_dim) { 6531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto result = 6541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]); 6551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < result->n1(); ++i) { 6561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 j = 0; j < result->n2(); ++j) { 6571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 k = 0; k < result->n3(); ++k) { 6581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 l = 0; l < result->n4(); ++l) { 6591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower switch (broadcast_from_dim) { 6601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 0: 6611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[i]; 6621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 1: 6641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[j]; 6651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 2: 6671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[k]; 6681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower case 3: 6701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower (*result)(i, j, k, l) = array[l]; 6711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower default: 6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower break; 6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower return result; 6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D( 6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array3D<float>& array, float init, 6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::gtl::ArraySlice<int64> dims, 6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::function<float(float, float)> reduce_function) { 6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(dims.size(), 1); 6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = dims[0] == 0 ? array.n2() : array.n1(); 6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = dims[0] == 2 ? array.n2() : array.n3(); 6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 6901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(init); 6911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i0 = 0; i0 < array.n1(); ++i0) { 6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i1 = 0; i1 < array.n2(); ++i1) { 6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i2 = 0; i2 < array.n3(); ++i2) { 6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 row = dims[0] == 0 ? i1 : i0; 6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 col = dims[0] == 2 ? i1 : i2; 6961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(row, col) = 6971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins reduce_function((*result)(row, col), array(i0, i1, i2)); 6981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 7051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 7061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float)>& map_function) { 7071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j)); 7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& lhs, const Array2D<float>& rhs, 7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, float)>& map_function) { 7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.height(), rhs.height()); 7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(lhs.width(), rhs.width()); 7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = lhs.height(); 7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = rhs.width(); 7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); 7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D( 7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& matrix, 7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::function<float(float, int64, int64)>& map_function) { 7371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 rows = matrix.height(); 7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 cols = matrix.width(); 7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(rows, cols); 7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 i = 0; i < rows; ++i) { 7411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int64 j = 0; j < cols; ++j) { 7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*result)(i, j) = map_function(matrix(i, j), i, j); 7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D( 7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const Array2D<float>& operand, const PaddingConfig& padding, 7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float pad) { 7511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in0 = operand.n1(); 7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding0 = padding.dimensions(0).edge_padding_high(); 7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding0 = padding.dimensions(0).edge_padding_low(); 7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding0 = padding.dimensions(0).interior_padding(); 7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out0 = 7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; 7575aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 in1 = operand.n2(); 7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 high_padding1 = padding.dimensions(1).edge_padding_high(); 7601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 low_padding1 = padding.dimensions(1).edge_padding_low(); 7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 interior_padding1 = padding.dimensions(1).interior_padding(); 7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 out1 = 7631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; 7645aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan 7651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result = MakeUnique<Array2D<float>>(out0, out1); 7661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins result->Fill(pad); 7675aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o0 = low_padding0; 7685aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i0 = 0; i0 < in0; ++i0) { 7695aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan int64 o1 = low_padding1; 7705aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan for (int64 i1 = 0; i1 < in1; ++i1) { 7715aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { 7725aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan (*result)(o0, o1) = operand(i0, i1); 7735aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan } 7745aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o1 += interior_padding1 + 1; 7751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7765aed36230aaa8e727990f1b8e37c82bed8b5356fMark Heffernan o0 += interior_padding0 + 1; 7771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 7781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return result; 7791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 781c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune/* static */ Array4D<float> ReferenceUtil::PadArray4D( 782c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const Array4D<float>& operand, const PaddingConfig& padding, 783c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const float pad) { 784c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune CHECK_EQ(padding.dimensions_size(), 4); 785c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 786c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), 787c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune operand.n3(), operand.n4()}; 788c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> pad_low(4); 789c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> pad_high(4); 7908ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower std::vector<int64> pad_interior(4); 791c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune std::vector<int64> output_bounds(4); 792c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune for (int64 i = 0; i < 4; ++i) { 793c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune pad_low[i] = padding.dimensions(i).edge_padding_low(); 794c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune pad_high[i] = padding.dimensions(i).edge_padding_high(); 7958ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; 7968ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower pad_interior[i] = padding.dimensions(i).interior_padding(); 797c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 7988ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + 7998ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (input_bounds[i] - 1) * pad_interior[i]; 800c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 801c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 802c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2], 803c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune output_bounds[3]); 804c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 805c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune for (int i = 0; i < 4; ++i) { 806c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune bool in_low_padding = indices[i] < pad_low[i]; 807c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; 808c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune if (in_low_padding || in_high_padding) { 809c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune *value = pad; 810c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune return; 811c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 8128ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower if (pad_interior[i] && 8138ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { 8148ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower *value = pad; 8158ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower return; 8168ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower } 817c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune } 8188ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), 8198ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[1] - pad_low[1]) / (pad_interior[1] + 1), 8208ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[2] - pad_low[2]) / (pad_interior[2] + 1), 8218ac9eb6318084fcb3d9dc71a7a7fc05bf18048d8A. Unique TensorFlower (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); 822c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune }); 823c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune return result; 824c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune} 825c1bd0fe248c63b58b0b663a8c8529791354fdf75Bjarke Hammersholt Roune 8261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 827