1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include <algorithm>
19#include <numeric>
20#include <unordered_map>
21#include <utility>
22#include <vector>
23
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/register_types.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_util.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/lib/gtl/inlined_vector.h"
30#include "tensorflow/core/util/sparse/sparse_tensor.h"
31
32namespace tensorflow {
33
34template <typename T>
35class SparseConcatOp : public OpKernel {
36 public:
37  explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
38    OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
39  }
40
41  void Compute(OpKernelContext* context) override {
42    OpInputList inds;
43    OP_REQUIRES_OK(context, context->input_list("indices", &inds));
44    const int N = inds.size();
45    for (int i = 0; i < N; i++) {
46      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
47                  errors::InvalidArgument(
48                      "Input indices should be a matrix but received shape ",
49                      inds[i].shape().DebugString(), " at position ", i));
50    }
51
52    OpInputList vals;
53    OP_REQUIRES_OK(context, context->input_list("values", &vals));
54    OP_REQUIRES(context, vals.size() == N,
55                errors::InvalidArgument("Expected ", N, " input values, got ",
56                                        vals.size()));
57    for (int i = 0; i < N; i++) {
58      OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
59                  errors::InvalidArgument(
60                      "Input values should be a vector but received shape ",
61                      vals[i].shape().DebugString(), " at position ", i));
62    }
63
64    OpInputList shapes;
65    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
66    OP_REQUIRES(context, shapes.size() == N,
67                errors::InvalidArgument("Expected ", N, " input shapes, got ",
68                                        shapes.size()));
69    for (int i = 0; i < N; i++) {
70      OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
71                  errors::InvalidArgument(
72                      "Input shapes should be a vector but received shape ",
73                      shapes[i].shape().DebugString(), " at position ", i));
74    }
75
76    const TensorShape input_shape(shapes[0].vec<int64>());
77    const int input_rank = input_shape.dims();
78    const int concat_dim = (concat_dim_attr_ < 0)
79                               ? input_rank + concat_dim_attr_
80                               : concat_dim_attr_;
81    OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
82                errors::InvalidArgument("Concat dimension must be in range [",
83                                        -input_rank, ", ", input_rank,
84                                        "), got ", concat_dim_attr_));
85    for (int i = 1; i < N; ++i) {
86      const TensorShape current_shape(shapes[i].vec<int64>());
87      OP_REQUIRES(
88          context, current_shape.dims() == input_rank,
89          errors::InvalidArgument(
90              "Ranks of all input tensors must match: expected ", input_rank,
91              " but got ", current_shape.dims(), " at position ", i));
92      for (int j = 0; j < input_rank; ++j) {
93        if (j != concat_dim) {
94          OP_REQUIRES(
95              context, input_shape.dim_size(j) == current_shape.dim_size(j),
96              errors::InvalidArgument(
97                  "Input shapes must match: expected ", input_shape.dim_size(j),
98                  " for dimension ", j, " but got ", current_shape.dim_size(j),
99                  " at position ", i));
100        }
101      }
102    }
103
104    // The input and output sparse tensors are assumed to be ordered along
105    // increasing dimension number. But in order for concat to work properly,
106    // order[0] must be concat_dim. So we will reorder the inputs to the
107    // concat ordering, concatenate, then reorder back to the standard order.
108    // We make a deep copy of the input tensors to ensure that the in-place
109    // reorder doesn't create race conditions for other ops that may be
110    // concurrently reading the indices and values tensors.
111
112    gtl::InlinedVector<int64, 8> std_order(input_rank);
113    std::iota(std_order.begin(), std_order.end(), 0);
114
115    std::vector<int64> concat_order;
116    concat_order.reserve(input_rank);
117    concat_order.push_back(concat_dim);
118    for (int j = 0; j < input_rank; ++j) {
119      if (j != concat_dim) {
120        concat_order.push_back(j);
121      }
122    }
123
124    std::vector<sparse::SparseTensor> sp_inputs;
125    for (int i = 0; i < N; ++i) {
126      const TensorShape current_shape(shapes[i].vec<int64>());
127      sp_inputs.emplace_back(tensor::DeepCopy(inds[i]),
128                             tensor::DeepCopy(vals[i]), current_shape,
129                             std_order);
130      sp_inputs[i].Reorder<T>(concat_order);
131    }
132
133    sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
134    concat.Reorder<T>(std_order);
135
136    context->set_output(0, concat.indices());
137    context->set_output(1, concat.values());
138
139    Tensor* output_shape_out = nullptr;
140    OP_REQUIRES_OK(context,
141                   context->allocate_output(2, TensorShape({concat.dims()}),
142                                            &output_shape_out));
143    auto output_shape = output_shape_out->vec<int64>();
144    auto concat_shape = concat.shape();
145    for (int j = 0; j < concat.dims(); ++j) {
146      output_shape(j) = concat_shape[j];
147    }
148  }
149
150 private:
151  int concat_dim_attr_;
152};
153
154#define REGISTER_KERNELS(type)                                           \
155  REGISTER_KERNEL_BUILDER(                                               \
156      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
157      SparseConcatOp<type>)
158
159TF_CALL_ALL_TYPES(REGISTER_KERNELS);
160#undef REGISTER_KERNELS
161}  // namespace tensorflow
162