/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <algorithm>
#include <limits>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef gtl::ArraySlice<int64> VarDimArray;
  typedef gtl::InlinedVector<int64, 8> ShapeArray;

  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(ix, vals, TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}

  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}

  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order)
      : ix_(ix),
        vals_(vals),
        shape_(shape.begin(), shape.end()),
        order_(order.begin(), order.end()),
        dims_(GetDimsFromIx(ix)) {
    CHECK_EQ(ix.dtype(), DT_INT64)
        << "indices must be type int64 but got: " << ix.dtype();
    CHECK(TensorShapeUtils::IsVector(vals.shape()))
        << "vals must be a vec, but got: " << vals.shape().DebugString();
    CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
        << "indices and values rows (indexing dimension) must match.";
    CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
    CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
  }

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const {
    const auto ix_t = ix_.matrix<int64>();
    for (int64 ord : order_) {
      if (ord < 0) {
        return errors::FailedPrecondition(
            "Order was not provided.  Provide an order at "
            "construction time or run Reorder");
      }
    }

    for (std::size_t n = 0; n < num_entries(); ++n) {
      TF_RETURN_IF_ERROR(IndexValid(ix_t, n));
    }

    return Status::OK();
  }

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Re-sorts the indices and values in place according to the dimensions
  // given in `order`.
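  //
  // A minimal usage sketch (illustrative only; `st` is a hypothetical
  // float SparseTensor):
  //
  //   st.Reorder<float>({0, 1});        // sort by dim 0, then dim 1
  //   TF_CHECK_OK(st.IndicesValid());   // order is now well-defined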
  template <typename T>
  void Reorder(const VarDimArray& order);

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
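  //
  // A minimal sketch of iterating over groups (illustrative only; assumes a
  // hypothetical float SparseTensor `st` already ordered as {0, 1}):
  //
  //   for (const auto& g : st.group({0})) {   // one group per dim-0 index
  //     const int64 group_id = g.group()[0];
  //     auto group_vals = g.values<float>();  // values within this group
  //     LOG(INFO) << "group " << group_id << ": " << group_vals.size();
  //   }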
  GroupIterable group(const VarDimArray& group_ix) const {
    CHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      CHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == shape().dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success and false on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize is true, ToDense first overwrites all coefficients in
  // out with zero.
  //
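  // A minimal usage sketch (illustrative only; `st` is a hypothetical float
  // SparseTensor with shape {3, 4}):
  //
  //   Tensor dense(DT_FLOAT, TensorShape({3, 4}));
  //   CHECK(st.ToDense<float>(&dense));  // zero-fills, then scatters values
  //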
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);

  // Concat() concatenates all the tensors along their first order
  // dimension.  All tensors must have identical shapes except for
  // the first order dimension, and the first dimension of every
  // tensor's order must match.
  //
  // If all of the tensors have identical ordering, then the output
  // will have this ordering.  Otherwise the output is marked as not
  // having any order, and Reorder<T>() should be called on it before
  // performing any subsequent operations.
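  //
  // A minimal usage sketch (illustrative only; `st1` and `st2` are
  // hypothetical float SparseTensors of equal rank, both ordered with
  // order()[0] == 0):
  //
  //   SparseTensor out = SparseTensor::Concat<float>({st1, st2});
  //   // out.shape()[0] == st1.shape()[0] + st2.shape()[0]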
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);

  // Split() splits the input SparseTensor into a list of num_split
  // SparseTensors along the given split dimension. If the size of that
  // dimension isn't an integer multiple of num_split, the first
  // (size % num_split) slices each get one extra element.
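  //
  // A minimal usage sketch (illustrative only; `st` is a hypothetical float
  // SparseTensor with shape {7, 4}; splitting dim 0 three ways yields slices
  // with dim-0 sizes 3, 2, and 2):
  //
  //   std::vector<SparseTensor> parts = SparseTensor::Split<float>(st, 0, 3);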
  template <typename T>
  static std::vector<SparseTensor> Split(const SparseTensor& tensor,
                                         const int split_dim,
                                         const int num_split);

  // Slice() slices the input SparseTensor based on the specified start and
  // size. Both start and size are 1-D arrays with one element per dimension:
  // start gives the starting index and size gives the slice size in each
  // dimension.
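  //
  // A minimal usage sketch (illustrative only; `st` is a hypothetical float
  // SparseTensor with shape {4, 6}; this keeps rows 1..2 and columns 2..5,
  // so the result has shape {2, 4} and indices shifted by the start offsets):
  //
  //   SparseTensor sub = SparseTensor::Slice<float>(st, {1, 2}, {2, 4});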
  template <typename T>
  static SparseTensor Slice(const SparseTensor& tensor,
                            const gtl::ArraySlice<int64>& start,
                            const gtl::ArraySlice<int64>& size);

  // Picks out the dimensions according to `dim_indices`.
  std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
    std::vector<int64> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }

 private:
  static int GetDimsFromIx(const Tensor& ix) {
    CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
        << "indices must be a matrix, but got: " << ix.shape().DebugString();
    return ix.dim_size(1);
  }

  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Helper for IndicesValid()
  inline Status IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
                           int n) const {
    bool valid = true;
    bool different = false;
    bool increasing = true;
    if (n == 0) {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
      }
      different = true;
    } else {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
        int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
        if (diff > 0) different = true;
        if (!different && diff < 0) increasing = false;
      }
    }
    if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
      string index = strings::StrCat("indices[", n, "] = [");
      for (int di = 0; di < dims_; ++di) {
        strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
      }
      if (!valid) {
        return errors::InvalidArgument(index,
                                       " is out of bounds: need 0 <= index < [",
                                       str_util::Join(shape_, ","), "]");
      }
      if (!increasing) {
        return errors::InvalidArgument(index, " is out of order");
      }
      if (!different) {
        return errors::InvalidArgument(index, " is repeated");
      }
    }
    return Status::OK();
  }

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }
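
  // Worked example: splitting a dimension of size 7 into num_split = 3 gives
  // split_size = 2 and residual = 1, so the slice sizes are {3, 2, 2}.  Then
  // GetSliceIndex(4, 2, 1) == 1 and GetDimensionInSlice(4, 2, 1) == 1:
  // global index 4 lands at position 1 of slice 1.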

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  const int dims_;
};
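
// A minimal end-to-end sketch, assuming a 2x3 float tensor with two nonzero
// entries; all names below are illustrative and not part of this header:
//
//   Tensor ix(DT_INT64, TensorShape({2, 2}));   // two entries, rank 2
//   Tensor vals(DT_FLOAT, TensorShape({2}));
//   ix.matrix<int64>()(0, 0) = 0; ix.matrix<int64>()(0, 1) = 1;
//   ix.matrix<int64>()(1, 0) = 1; ix.matrix<int64>()(1, 1) = 2;
//   vals.vec<float>()(0) = 3.0f; vals.vec<float>()(1) = 4.0f;
//
//   SparseTensor st(ix, vals, TensorShape({2, 3}), {0, 1});
//   TF_CHECK_OK(st.IndicesValid());
//   Tensor dense(DT_FLOAT, TensorShape({2, 3}));
//   CHECK(st.ToDense<float>(&dense));  // dense(0,1) == 3, dense(1,2) == 4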

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm.  It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
  CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                    \
  case ORDER_SIZE: {                                             \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
    std::sort(reorder.begin(), reorder.end(), sorter);           \
    break;                                                       \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is the inverse
  // permutation.  This could be computed with O(1) additional space and
  // O(N) time (INVPERM), but we just do the simple thing here.
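  // For example, if reorder = {2, 0, 1} (entry 2 sorts first), then
  // permutation = {1, 2, 0}: the entry currently in row 0 must move to
  // row 1, row 1 to row 2, and row 2 to row 0.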
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }

  // Update indices & values by converting the permutations to
  // a product of transpositions.  Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
  CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  CHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

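  // Compute row-major strides for the dense output: e.g. an output shape of
  // {3, 4, 5} yields strides {20, 5, 1}, so entry (i, j, k) lands at flat
  // offset 20*i + 5*j + k.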
  std::vector<int64> strides(dims_);
  const auto& out_shape = out->shape();
  if (dims_ > 0) {
    strides[dims_ - 1] = 1;
  }
  for (int d = dims_ - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
  }

  for (int n = 0; n < vals_t.dimension(0); ++n) {
    bool invalid_dims = false;
    int64 ix = 0;
    for (int d = 0; d < dims_; ++d) {
      const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d));
      if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) {
        invalid_dims = true;
      }
      ix += strides[d] * ix_n_d;
    }
    if (invalid_dims) return false;
    out_t(ix) = vals_t(n);
  }
  return true;
}

template <typename T>
SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    CHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match.  This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      CHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim.  "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ".  Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the ordering is inconsistent among the inputs, set the final order
  // to undefined (all -1s).
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64 shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

    const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
    auto* ix_out = &ix_t(offset, 0);
    for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
      *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
                                              const int split_dim,
                                              const int num_split) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  CHECK(split_dim >= 0 && split_dim < num_dim)
      << "split_dim must be in the interval [0, " << num_dim << ")";
  const int split_dim_size = input_tensor.shape()[split_dim];
  CHECK(num_split > 0 && num_split <= split_dim_size)
      << "num_split must be in the interval (0, " << split_dim_size << "]";
  const int split_size = split_dim_size / num_split;

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64 original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  std::vector<SparseTensor> output_tensors;
  output_tensors.reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    output_tensors.emplace_back(output_indices[i], output_values[i],
                                output_shapes[i]);
  }
  return output_tensors;
}

template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
                                 const gtl::ArraySlice<int64>& start,
                                 const gtl::ArraySlice<int64>& size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
                         ? size[dim]
                         : output_shape.dim_size(dim) - start[dim];
    output_shape.set_dim(dim, dim_size);
  }

  auto input_indices_t = input_tensor.indices().matrix<int64>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Count the number of indices that fall inside start and size.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    // An entry is a "hit" only if its index lies within [start, start + size)
    // in every dimension; if any dimension falls outside that range, the
    // entry is skipped and not counted.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

  Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
  Tensor output_indices(DT_INT64, TensorShape({count, dims}));

  auto output_values_t = output_values.vec<T>();
  auto output_indices_t = output_indices.matrix<int64>();

  // Obtain the output indices that fall inside start and size.
  int index = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
       i++) {
    // The logic here is similar to the above, except that the loop above
    // only counts the qualifying indices while here we actually generate
    // the output.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_