transpose_op.cc revision c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
// XLA-specific Transpose Op. This is very different from the Eigen
// version in third_party/tensorflow because XLA's Transpose operation
// neatly handles arbitrary permutations, while Eigen needs a restricted
// DoTranspose helper.
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/transpose_op.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/register_types.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/bounds_check.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass TransposeOp : public XlaOpKernel {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape perm_tensor_shape = ctx->InputShape(1);
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Preliminary validation of sizes.
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("perm must be a vector, not ",
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        perm_tensor_shape.DebugString()));
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int dims = input_shape.dims();
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("transpose expects a vector of size ",
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.dims(),
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        ". But input(1) is a vector of size ",
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        perm_tensor_shape.num_elements()));
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::Literal literal;
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int32> perm(dims);
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin());
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> transposed_order;
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Check whether permutation is a permutation of integers of [0 .. dims).
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    gtl::InlinedVector<bool, 8> bits(dims);
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    bool is_identity = true;
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < dims; ++i) {
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const int32 d = perm[i];
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ctx, 0 <= d && d < dims,
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      bits[d] = true;
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      transposed_order.push_back(d);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      if (d != i) {
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        is_identity = false;
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < dims; ++i) {
75c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins      OP_REQUIRES(
76c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins          ctx, bits[i],
77c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins          errors::InvalidArgument(i, " is missing from 'perm' argument."));
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // 0-D, 1-D, and identity transposes do nothing.
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    if (dims <= 1 || is_identity) {
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ctx->SetOutput(0, ctx->Input(0));
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return;
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0,
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                   ctx->builder()->Transpose(ctx->Input(0), transposed_order));
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// InvertPermutation frequently forms part of the gradient of Transpose.
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// inv = InvertPermutationOp(T<int32> p) takes a permutation of
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// integers 0, 1, ..., n - 1 and returns the inverted
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a vector of int32.
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// REQUIRES: input is a permutation of 0, 1, ..., n-1.
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass InvertPermutationOp : public XlaOpKernel {
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
107c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins    OP_REQUIRES(ctx,
108c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                FastBoundsCheck(ctx->InputShape(0).num_elements(),
109c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                                std::numeric_limits<int32>::max()),
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("permutation of nonnegative int32s "
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        "must have <= int32 max elements"));
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> perm;
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int size = perm.size();
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int32> output(size);
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::fill_n(output.data(), size, -1);
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < size; ++i) {
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const int64 d = perm[i];
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(ctx, FastBoundsCheck(d, size),
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  errors::InvalidArgument(d, " is not between 0 and ", size));
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(ctx, output[d] == -1,
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  errors::InvalidArgument(d, " is duplicated in the input."));
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      output[d] = i;
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
133c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("InvertPermutation")
134c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .TypeConstraint("T", DT_INT32)
135c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .CompileTimeConstInput("x"),
13693f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter Hawkins                InvertPermutationOp);
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
140