1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
210e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan
310e62dc1e008d08cd877a0f891d486daba8f6288Vijay VasudevanLicensed under the Apache License, Version 2.0 (the "License");
410e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevanyou may not use this file except in compliance with the License.
510e62dc1e008d08cd877a0f891d486daba8f6288Vijay VasudevanYou may obtain a copy of the License at
610e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan
710e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan    http://www.apache.org/licenses/LICENSE-2.0
810e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan
910e62dc1e008d08cd877a0f891d486daba8f6288Vijay VasudevanUnless required by applicable law or agreed to in writing, software
1010e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevandistributed under the License is distributed on an "AS IS" BASIS,
1110e62dc1e008d08cd877a0f891d486daba8f6288Vijay VasudevanWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1210e62dc1e008d08cd877a0f891d486daba8f6288Vijay VasudevanSee the License for the specific language governing permissions and
1310e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevanlimitations under the License.
1410e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan==============================================================================*/
1510e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan
160a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#define EIGEN_USE_THREADS
170a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
18b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include <vector>
190a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/framework/op_kernel.h"
200a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/framework/register_types.h"
210a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/util/sparse/sparse_tensor.h"
220a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
230a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevannamespace tensorflow {
240a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
250a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevantemplate <typename T>
260a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevanclass SparseSplitOp : public OpKernel {
270a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan public:
280a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) {
290a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
300a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  }
310a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
320a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  void Compute(OpKernelContext* context) override {
33769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen    const int64 split_dim = context->input(0).scalar<int64>()();
340a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    const Tensor& input_indices = context->input(1);
350a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    const Tensor& input_values = context->input(2);
360a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    const Tensor& input_shape = context->input(3);
370a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
380a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
390a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                errors::InvalidArgument(
4059f1eba5fb94506a205fa2e81145667754739da5Martin Wicke                    "Input indices should be a matrix but received shape ",
41d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving                    input_indices.shape().DebugString()));
420a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
430a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                errors::InvalidArgument(
440a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                    "Input values should be a vector but received shape ",
45d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving                    input_indices.shape().DebugString()));
460a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape.shape()),
470a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                errors::InvalidArgument(
480a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                    "Input shape should be a vector but received shape ",
49d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving                    input_shape.shape().DebugString()));
500a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
51982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    OP_REQUIRES(
52982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        context,
53982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        input_shape.dim_size(0) && split_dim < input_shape.vec<int64>().size(),
54982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        errors::InvalidArgument(
55982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen            "Input split_dim should be between 0 and rank (",
56982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen            input_shape.vec<int64>().size(), "), got ", split_dim));
570a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
58982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    OP_REQUIRES(
59982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        context,
60982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(split_dim),
61982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        errors::InvalidArgument("Input num_split should be between 1 "
62982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                "and the splitting dimension size (",
63982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                input_shape.vec<int64>()(split_dim), "), got ",
64982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                num_split_));
650a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
660a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    sparse::SparseTensor sparse_tensor(input_indices, input_values,
670a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                                       TensorShape(input_shape.vec<int64>()));
680a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    const std::vector<sparse::SparseTensor> outputs =
690a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan        sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_);
700a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
710a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
720a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      context->set_output(slice_index, outputs[slice_index].indices());
730a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      context->set_output(slice_index + num_split_,
740a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan                          outputs[slice_index].values());
750a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      Tensor* shape = nullptr;
76cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo      OP_REQUIRES_OK(context, context->allocate_output(
77cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo                                  slice_index + 2 * num_split_,
78cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo                                  {outputs[slice_index].dims()}, &shape));
79cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo      auto output_shape = outputs[slice_index].shape();
80cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo      for (int dim = 0; dim < outputs[slice_index].dims(); ++dim) {
81cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo        shape->vec<int64>()(dim) = output_shape[dim];
820a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      }
830a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan    }
840a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  }
850a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
860a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan private:
870a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  int num_split_;
880a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan};
890a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
900a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#define REGISTER_KERNELS(type)                                          \
910a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan  REGISTER_KERNEL_BUILDER(                                              \
920a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
930a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan      SparseSplitOp<type>)
940a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
950a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay VasudevanTF_CALL_ALL_TYPES(REGISTER_KERNELS);
960a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#undef REGISTER_KERNELS
970a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan
980a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan}  // namespace tensorflow
99