/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
1410e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan==============================================================================*/ 1510e62dc1e008d08cd877a0f891d486daba8f6288Vijay Vasudevan 160a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#define EIGEN_USE_THREADS 170a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 18b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include <vector> 190a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/framework/op_kernel.h" 200a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/framework/register_types.h" 210a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#include "tensorflow/core/util/sparse/sparse_tensor.h" 220a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 230a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevannamespace tensorflow { 240a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 250a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevantemplate <typename T> 260a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevanclass SparseSplitOp : public OpKernel { 270a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan public: 280a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) { 290a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_)); 300a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan } 310a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 320a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan void Compute(OpKernelContext* context) override { 33769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen const int64 split_dim = context->input(0).scalar<int64>()(); 340a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan const Tensor& input_indices = context->input(1); 350a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan const Tensor& input_values = context->input(2); 
360a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan const Tensor& input_shape = context->input(3); 370a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 380a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), 390a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan errors::InvalidArgument( 4059f1eba5fb94506a205fa2e81145667754739da5Martin Wicke "Input indices should be a matrix but received shape ", 41d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving input_indices.shape().DebugString())); 420a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()), 430a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan errors::InvalidArgument( 440a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan "Input values should be a vector but received shape ", 45d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving input_indices.shape().DebugString())); 460a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape.shape()), 470a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan errors::InvalidArgument( 480a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan "Input shape should be a vector but received shape ", 49d552be23658b3bdd1b7dedd34f25631773e81dffGeoffrey Irving input_shape.shape().DebugString())); 500a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 51982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES( 52982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, 53982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen input_shape.dim_size(0) && split_dim < input_shape.vec<int64>().size(), 54982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument( 55982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "Input split_dim should be between 0 and rank (", 
56982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen input_shape.vec<int64>().size(), "), got ", split_dim)); 570a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 58982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES( 59982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, 60982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(split_dim), 61982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Input num_split should be between 1 " 62982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "and the splitting dimension size (", 63982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen input_shape.vec<int64>()(split_dim), "), got ", 64982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen num_split_)); 650a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 660a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan sparse::SparseTensor sparse_tensor(input_indices, input_values, 670a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan TensorShape(input_shape.vec<int64>())); 680a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan const std::vector<sparse::SparseTensor> outputs = 690a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_); 700a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 710a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan for (int slice_index = 0; slice_index < num_split_; ++slice_index) { 720a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan context->set_output(slice_index, outputs[slice_index].indices()); 730a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan context->set_output(slice_index + num_split_, 740a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan outputs[slice_index].values()); 750a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan Tensor* shape = nullptr; 
76cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo OP_REQUIRES_OK(context, context->allocate_output( 77cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo slice_index + 2 * num_split_, 78cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo {outputs[slice_index].dims()}, &shape)); 79cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo auto output_shape = outputs[slice_index].shape(); 80cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo for (int dim = 0; dim < outputs[slice_index].dims(); ++dim) { 81cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo shape->vec<int64>()(dim) = output_shape[dim]; 820a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan } 830a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan } 840a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan } 850a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 860a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan private: 870a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan int num_split_; 880a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan}; 890a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 900a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#define REGISTER_KERNELS(type) \ 910a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan REGISTER_KERNEL_BUILDER( \ 920a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 930a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan SparseSplitOp<type>) 940a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 950a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay VasudevanTF_CALL_ALL_TYPES(REGISTER_KERNELS); 960a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan#undef REGISTER_KERNELS 970a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan 980a21a38d4ef5b66177f407f74f14dd7b72232b36Vijay Vasudevan} // namespace tensorflow 99