11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
// XLA-specific Transpose Op. This is very different from the Eigen
// version in third_party/tensorflow because XLA's Transpose operation
// neatly handles all transposes, while Eigen needs a restricted
// DoTranspose helper.
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/transpose_op.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/register_types.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/bounds_check.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass TransposeOp : public XlaOpKernel {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape perm_tensor_shape = ctx->InputShape(1);
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Preliminary validation of sizes.
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("perm must be a vector, not ",
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        perm_tensor_shape.DebugString()));
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int dims = input_shape.dims();
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("transpose expects a vector of size ",
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.dims(),
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        ". But input(1) is a vector of size ",
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        perm_tensor_shape.num_elements()));
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::Literal literal;
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int32> perm(dims);
577d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan    std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
587d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan              perm.begin());
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> transposed_order;
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Check whether permutation is a permutation of integers of [0 .. dims).
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    gtl::InlinedVector<bool, 8> bits(dims);
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    bool is_identity = true;
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < dims; ++i) {
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const int32 d = perm[i];
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ctx, 0 <= d && d < dims,
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      bits[d] = true;
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      transposed_order.push_back(d);
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      if (d != i) {
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        is_identity = false;
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < dims; ++i) {
76c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins      OP_REQUIRES(
77c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins          ctx, bits[i],
78c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins          errors::InvalidArgument(i, " is missing from 'perm' argument."));
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // 0-D, 1-D, and identity transposes do nothing.
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    if (dims <= 1 || is_identity) {
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ctx->SetOutput(0, ctx->Input(0));
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return;
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0,
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   ctx->builder()->Transpose(ctx->Input(0), transposed_order));
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
92c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// InvertPermutation frequently forms part of the gradient of Transpose.
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// inv = InvertPermutationOp(T<int32> p) takes a permutation of
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// integers 0, 1, ..., n - 1 and returns the inverted
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a vector of int32.
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a permutation of 0, 1, ..., n-1.
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass InvertPermutationOp : public XlaOpKernel {
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
108c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins    OP_REQUIRES(ctx,
109c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                FastBoundsCheck(ctx->InputShape(0).num_elements(),
110c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                                std::numeric_limits<int32>::max()),
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("permutation of nonnegative int32s "
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        "must have <= int32 max elements"));
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> perm;
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int size = perm.size();
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int32> output(size);
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::fill_n(output.data(), size, -1);
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < size; ++i) {
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const int64 d = perm[i];
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(ctx, FastBoundsCheck(d, size),
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  errors::InvalidArgument(d, " is not between 0 and ", size));
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(ctx, output[d] == -1,
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  errors::InvalidArgument(d, " is duplicated in the input."));
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      output[d] = i;
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
134c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("InvertPermutation")
135c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .TypeConstraint("T", DT_INT32)
136c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .CompileTimeConstInput("x"),
13793f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter Hawkins                InvertPermutationOp);
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
141