/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Transpose Op. This is very different from the Eigen-based
// kernel (see tensorflow/core/kernels/transpose_op.h) because XLA's Transpose
// operation neatly handles every permutation, while Eigen needs a restricted
// DoTranspose helper.
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/transpose_op.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 25a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/register_types.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/bounds_check.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow { 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass TransposeOp : public XlaOpKernel { 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape input_shape = ctx->InputShape(0); 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape perm_tensor_shape = ctx->InputShape(1); 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Preliminary validation of sizes. 
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("perm must be a vector, not ", 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins perm_tensor_shape.DebugString())); 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int dims = input_shape.dims(); 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(), 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("transpose expects a vector of size ", 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins input_shape.dims(), 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ". But input(1) is a vector of size ", 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins perm_tensor_shape.num_elements())); 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::Literal literal; 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int32> perm(dims); 577d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan std::copy(literal.data<int32>().begin(), literal.data<int32>().end(), 587d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan perm.begin()); 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> transposed_order; 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Check whether permutation is a permutation of integers of [0 .. dims). 
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins gtl::InlinedVector<bool, 8> bits(dims); 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins bool is_identity = true; 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < dims; ++i) { 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int32 d = perm[i]; 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES( 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx, 0 <= d && d < dims, 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins bits[d] = true; 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins transposed_order.push_back(d); 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (d != i) { 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins is_identity = false; 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < dims; ++i) { 76c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins OP_REQUIRES( 77c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins ctx, bits[i], 78c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins errors::InvalidArgument(i, " is missing from 'perm' argument.")); 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 0-D, 1-D, and identity transposes do nothing. 
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (dims <= 1 || is_identity) { 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, ctx->Input(0)); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return; 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->builder()->Transpose(ctx->Input(0), transposed_order)); 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 92c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// InvertPermutation frequently forms part of the gradient of Transpose. 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// inv = InvertPermutationOp(T<int32> p) takes a permutation of 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// integers 0, 1, ..., n - 1 and returns the inverted 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a vector of int32. 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a permutation of 0, 1, ..., n-1. 
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass InvertPermutationOp : public XlaOpKernel { 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 108c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins OP_REQUIRES(ctx, 109c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins FastBoundsCheck(ctx->InputShape(0).num_elements(), 110c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins std::numeric_limits<int32>::max()), 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("permutation of nonnegative int32s " 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "must have <= int32 max elements")); 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> perm; 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int size = perm.size(); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int32> output(size); 1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::fill_n(output.data(), size, -1); 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < size; ++i) { 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 d = perm[i]; 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, FastBoundsCheck(d, size), 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
errors::InvalidArgument(d, " is not between 0 and ", size)); 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, output[d] == -1, 1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument(d, " is duplicated in the input.")); 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins output[d] = i; 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output)); 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 134c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("InvertPermutation") 135c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins .TypeConstraint("T", DT_INT32) 136c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins .CompileTimeConstInput("x"), 13793f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter Hawkins InvertPermutationOp); 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace tensorflow 141