1b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 3b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFloweryou may not use this file except in compliance with the License. 5b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerYou may obtain a copy of the License at 6b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 7b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 9b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerSee the License for the specific language governing permissions and 13b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerlimitations under the License. 14b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower==============================================================================*/ 15b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 16b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_ 17b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#define TENSORFLOW_KERNELS_SHAPE_OPS_H_ 18b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 19b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <limits> 20b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <unordered_set> 21b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <vector> 22b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 23b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h" 24b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 25b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h" 2681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo#include "tensorflow/core/framework/variant_op_registry.h" 27b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/kernels/bounds_check.h" 28b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 29b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowernamespace tensorflow { 30b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 3181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdonamespace shape_op_helpers { 3281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdoinline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index, 3381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TensorShape* shape) { 3481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo const Tensor& inp = ctx->input(input_index); 3581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo if (ctx->input_dtype(0) == DT_VARIANT) { 3681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo if (inp.dims() != 0) { 3781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo return errors::InvalidArgument( 380acf5bb38a8f208c6d9f048579a076d5bc6ff0beEugene Brevdo "Shape of non-unary Variant not supported."); 3981ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo } 4081ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape)); 4181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo } else { 4281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo *shape = inp.shape(); 4381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo } 4481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo return Status::OK(); 4581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo} 4681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo} // namespace shape_op_helpers 4781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo 48b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType> 49b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ShapeOp : public OpKernel { 50b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 51b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 52b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 53b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 5481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TensorShape shape; 5581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES_OK(ctx, 5681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 5781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo const int rank = shape.dims(); 58b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* out = nullptr; 59b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out)); 60b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower auto vec = out->vec<OutType>(); 61b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (int i = 0; i < rank; ++i) { 6281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo int64 dim_size = shape.dim_size(i); 63b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (out->dtype() == DT_INT32) { 64b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES( 65b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 66b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::InvalidArgument("Shape output type is 32-bit ", " but dim ", 67b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower i, " is ", dim_size)); 68b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 69b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower vec(i) = static_cast<OutType>(dim_size); 70b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 71b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 72b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 73b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 74b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 75b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 76b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType> 77b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ShapeNOp : public OpKernel { 78b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 79b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 80b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 81b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 82b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (int i = 0; i < ctx->num_inputs(); ++i) { 8381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TensorShape shape; 8481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES_OK( 8581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo ctx, shape_op_helpers::GetRegularOrVariantShape(ctx, i, &shape)); 86b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower const int dims = shape.dims(); 87b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* out = nullptr; 88b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); 89b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower auto vec = out->vec<OutType>(); 90b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 91b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (int j = 0; j < dims; ++j) { 92b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower int64 dim_size = shape.dim_size(j); 93b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (out->dtype() == DT_INT32) { 94b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES( 95b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 96b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::InvalidArgument("ShapeN output type is 32-bit but shape ", 97b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower i, " dim ", j, " is ", dim_size)); 98b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 99b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower vec(j) = static_cast<OutType>(dim_size); 100b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 101b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 102b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 103b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 104b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 105b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 106b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 107b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass RankOp : public OpKernel { 108b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 109b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 110b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 111b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 11281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TensorShape shape; 11381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES_OK(ctx, 11481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 11581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo const int rank = shape.dims(); 116b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* out = nullptr; 117b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 118b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower out->scalar<int32>()() = rank; 119b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 120b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 121b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 122b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 123b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 124b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType> 125b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass SizeOp : public OpKernel { 126b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 127b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 128b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 129b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 13081ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo TensorShape shape; 13181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES_OK(ctx, 13281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 13381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo const int64 size = shape.num_elements(); 134b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* out = nullptr; 135b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 136b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (out->dtype() == DT_INT32) { 137b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES( 138b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()), 139b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::InvalidArgument("Number of elements was larger than " 140b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower "representable by 32-bit output type")); 141b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 142b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower out->scalar<OutType>()() = static_cast<OutType>(size); 143b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 144b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 145b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 146b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 147b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 148d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicketemplate <typename Tdim> 149b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ExpandDimsOp : public OpKernel { 150b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 151b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 152b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 153b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 15481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, 15581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo errors::InvalidArgument("ExpandDims on Variant not supported")); 15681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo 157d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicke Tdim dim = ctx->input(1).flat<Tdim>()(0); 158b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES( 159b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), 160b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::InvalidArgument("Tried to expand dim index ", dim, 161b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " for tensor with ", ctx->input(0).dims(), 162b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " dimensions.")); 163b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 164b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower auto existing_dims = ctx->input(0).shape().dim_sizes(); 165b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // Safe - # elements in tensor dims bounded. 166b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower const int existing_dims_size = static_cast<int>(existing_dims.size()); 167b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower std::vector<int64> new_shape(existing_dims_size); 168b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (size_t i = 0; i < new_shape.size(); ++i) { 169b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower new_shape[i] = existing_dims[i]; 170b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 171b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 172b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // We emulate numpy's interpretation of the dim axis when 173b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // -input.dims() >= dim <= input.dims(). 174b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (dim < 0) { 175b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower dim += existing_dims.size() + 1; 176b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 177b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 178b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // Clamp to the end if needed. 179d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicke dim = std::min<Tdim>(dim, existing_dims_size); 180b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower new_shape.emplace(new_shape.begin() + dim, 1); 181b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower const TensorShape output_shape(new_shape); 182b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 183b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* output = nullptr; 184b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 185b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (!output->CopyFrom(ctx->input(0), output_shape)) { 186b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // This should never happen, since the sizes of the input and output 187b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // should always be the same (we only expand the dimension with 1). 188b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx->SetStatus( 189b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::Internal("Could not expand dimension with input shape ", 190b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx->input(0).shape().DebugString(), 191b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " and output shape ", output_shape.DebugString())); 192b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 193b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 194b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 195b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 196b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 197b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 198b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass SqueezeOp : public OpKernel { 199b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public: 200b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 201b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower std::vector<int32> squeeze_dims; 202b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); 203b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); 204b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 205b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 206b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower void Compute(OpKernelContext* ctx) override { 20781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, 20881ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo errors::InvalidArgument("Squeeze on Variant not supported")); 20981ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo 210b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower auto existing_dims = ctx->input(0).shape().dim_sizes(); 211b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower const int existing_dims_size = static_cast<int>(existing_dims.size()); 212b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower std::vector<int64> new_shape; 213b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 214b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower std::unordered_set<int32> wrapped_squeeze_dims; 215b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower wrapped_squeeze_dims.reserve(squeeze_dims_.size()); 216b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // Validate squeeze dims against the input. 217b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (int32 dim : squeeze_dims_) { 218b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES( 219b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()), 220b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower errors::InvalidArgument("Tried to squeeze dim index ", dim, 221b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " for tensor with ", ctx->input(0).dims(), 222b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " dimensions.")); 223b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // If dim is < 0, we wrap around (-1 means the last element). 224b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (dim < 0) { 225b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower dim = existing_dims_size + dim; 226b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 227b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 228b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower wrapped_squeeze_dims.insert(dim); 229b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 230b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 231b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower for (int i = 0; i < existing_dims_size; ++i) { 232b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower auto existing_dim = existing_dims[i]; 233b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 234b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // If squeeze_set is non-empty, only squeeze those dimensions. 235b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (!wrapped_squeeze_dims.empty()) { 236b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (wrapped_squeeze_dims.count(i) > 0) { 237b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES(ctx, existing_dim == 1, 238982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument( 239982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "Tried to explicitly squeeze " 240982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "dimension ", 241982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen i, " but dimension was not 1: ", existing_dim)); 242b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } else { 243b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // This dimension is not being squeezed. 244b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower new_shape.push_back(existing_dim); 245b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 246b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } else { 247b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // Copy over all non-1-length dimensions. 248b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (existing_dim != 1) { 249b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower new_shape.push_back(existing_dim); 250b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 251b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 252b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 253b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 254b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower const TensorShape output_shape(new_shape); 255b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower Tensor* output = nullptr; 256b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 257b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower if (!output->CopyFrom(ctx->input(0), output_shape)) { 258b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // This should never happen, since the sizes of the input and 259b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower // output should always be the same. 260b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx->SetStatus(errors::Internal("Could not squeeze input with shape ", 261b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower ctx->input(0).shape().DebugString(), 262b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower " and output shape ", 263b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower output_shape.DebugString())); 264b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 265b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower } 266b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 267b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower bool IsExpensive() override { return false; } 268b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 269b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower private: 270b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower std::unordered_set<int32> squeeze_dims_; 271b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}; 272b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 273b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower} // namespace tensorflow 274b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower 275b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#endif // TENSORFLOW_KERNELS_SHAPE_OPS_H_ 276