1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/framework/tensor_slice.h"
#include <limits>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
23
24namespace tensorflow {
25
26TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
27
28TensorSlice::TensorSlice(const TensorSliceProto& proto) {
29  starts_.reserve(proto.extent_size());
30  lengths_.reserve(proto.extent_size());
31  for (const auto& e : proto.extent()) {
32    starts_.push_back(e.start());
33    lengths_.push_back(GetExtentLength(e));
34  }
35}
36
37TensorSlice::TensorSlice(
38    std::initializer_list<std::pair<int64, int64>> extents) {
39  starts_.reserve(extents.size());
40  lengths_.reserve(extents.size());
41  for (const auto& e : extents) {
42    starts_.push_back(e.first);
43    lengths_.push_back(e.second);
44  }
45}
46
47Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
48  std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
49  slice->starts_.reserve(items.size());
50  slice->lengths_.reserve(items.size());
51  for (const string& x : items) {
52    int64 s, l;
53    if (x == "-") {
54      // "everything"
55      s = 0;
56      l = kFullExtent;
57    } else {
58      std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
59      if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
60          !strings::safe_strto64(sl[1], &l)) {
61        return errors::InvalidArgument(
62            "Expected a pair of numbers or '-' "
63            "but got '",
64            x, "': string = ", str);
65      }
66      if (s < 0 || l <= 0) {
67        return errors::InvalidArgument(
68            "Expected non-negative start and "
69            "positive length but got start = ",
70            s, ", length = ", l, ": string = ", str);
71      }
72    }
73    slice->starts_.push_back(s);
74    slice->lengths_.push_back(l);
75  }
76
77  return Status::OK();
78}
79
80void TensorSlice::Clear() {
81  starts_.clear();
82  lengths_.clear();
83}
84
85bool TensorSlice::IsFull() const {
86  for (int d = 0; d < dims(); ++d) {
87    if (!IsFullAt(d)) return false;
88  }
89  return true;
90}
91
92void TensorSlice::SetFullSlice(int dim) {
93  Clear();
94  starts_.reserve(dim);
95  lengths_.reserve(dim);
96  for (int d = 0; d < dim; ++d) {
97    starts_.push_back(0);
98    lengths_.push_back(kFullExtent);
99  }
100}
101
102void TensorSlice::Extend(int dim) {
103  int old_dim = dims();
104  DCHECK_LE(old_dim, dim);
105  starts_.resize(dim);
106  lengths_.resize(dim);
107  for (int d = old_dim; d < dim; ++d) {
108    starts_[d] = 0;
109    lengths_[d] = kFullExtent;
110  }
111}
112
113void TensorSlice::AsProto(TensorSliceProto* proto) const {
114  for (int d = 0; d < dims(); ++d) {
115    TensorSliceProto::Extent* e = proto->add_extent();
116    // We only need to record the explicit slice for non-full slices
117    if (!IsFullAt(d)) {
118      e->set_start(starts_[d]);
119      e->set_length(lengths_[d]);
120    }
121  }
122}
123
124string TensorSlice::DebugString() const {
125  string buffer;
126  bool first = true;
127  for (int d = 0; d < dims(); ++d) {
128    if (!first) {
129      buffer.append(":");
130    }
131    string s;
132    if (IsFullAt(d)) {
133      buffer.append("-");
134    } else {
135      strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
136    }
137    first = false;
138  }
139  return buffer;
140}
141
142bool TensorSlice::Intersect(const TensorSlice& other,
143                            TensorSlice* result) const {
144  // First, if two slices have different ranks, they obviously don't overlap
145  // -- in fact they are not compatible.
146  if (dims() != other.dims()) {
147    return false;
148  }
149
150  // Setting the result to the right dimension
151  if (result) {
152    result->SetFullSlice(dims());
153  }
154  // The two slices overlap if they overlap in all dimensions.
155  for (int d = 0; d < dims(); ++d) {
156    if (IsFullAt(d)) {
157      if (result) {
158        result->set_start(d, other.start(d));
159        result->set_length(d, other.length(d));
160      }
161    } else if (other.IsFullAt(d)) {
162      if (result) {
163        result->set_start(d, start(d));
164        result->set_length(d, length(d));
165      }
166    } else {
167      // If we have an intersection here, it should have a start that is the
168      // max of the two starts and an end that is the min of the two ends.
169      int64 s = std::max(start(d), other.start(d));
170      int64 l = std::min(end(d), other.end(d)) - s;
171      if (l > 0) {
172        // We have a real intersection
173        if (result) {
174          result->set_start(d, s);
175          result->set_length(d, l);
176        }
177      } else {
178        // We don't have an intersection for this dimension -- thus we don't
179        // have any intersection at all.
180        if (result) {
181          result->Clear();
182        }
183        return false;
184      }
185    }
186  }
187  // If we are here, we know there is overlap in every dimension.
188  return true;
189}
190
191bool TensorSlice::operator==(const TensorSlice& other) const {
192  return dims() == other.dims() && starts_ == other.starts_ &&
193         lengths_ == other.lengths_;
194}
195
196void TensorSlice::ComputeRelative(const TensorSlice& sub,
197                                  TensorSlice* relative) const {
198  DCHECK_EQ(dims(), sub.dims());
199  relative->SetFullSlice(dims());
200  for (int d = 0; d < dims(); ++d) {
201    if (IsFullAt(d)) {
202      relative->set_start(d, sub.start(d));
203      relative->set_length(d, sub.length(d));
204    } else {
205      // Otherwise the relative start is the difference between the start of
206      // sub and the start of base
207      relative->set_start(d, sub.start(d) - start(d));
208      relative->set_length(d, sub.length(d));
209    }
210  }
211}
212
213void TensorSlice::UpdateToCover(const TensorSlice& other) {
214  DCHECK_EQ(dims(), other.dims());
215  for (int d = 0; d < dims(); ++d) {
216    if (!IsFullAt(d)) {
217      if (other.IsFullAt(d)) {
218        starts_[d] = 0;
219        lengths_[d] = kFullExtent;
220      } else {
221        const auto new_end = std::max(end(d), other.end(d));
222        set_start(d, std::min(start(d), other.start(d)));
223        set_length(d, new_end - start(d));
224      }
225    }
226  }
227}
228
229// static
230bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
231  return extent.has_length_case() == TensorSliceProto::Extent::kLength;
232}
233
234// static
235int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
236  if (!HasExtentLength(extent)) return -1;
237  return extent.length();
238}
239
240Status TensorSlice::SliceTensorShape(const TensorShape& shape,
241                                     TensorShape* result_shape) const {
242  result_shape->Clear();
243  // Mismatching ranks: we can't apply the slice at all.
244  if (shape.dims() != dims()) {
245    return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
246                            ", slice = ", DebugString());
247  }
248  for (int d = 0; d < dims(); ++d) {
249    if (IsFullAt(d)) {
250      result_shape->AddDim(shape.dim_size(d));
251    } else {
252      // Check if the extent applies to the dimension
253      if (end(d) <= shape.dim_size(d)) {
254        // Yes: the end is within the range of the dim -- we adjust the result
255        // shape so that its size along this dimension is the length of the
256        // slice.
257        result_shape->AddDim(length(d));
258      } else {
259        // The extent doesn't apply to the dimension
260        result_shape->Clear();
261        return errors::Internal("Extent in dimension ", d,
262                                " out of bounds: shape = ", shape.DebugString(),
263                                ", slice = ", DebugString());
264      }
265    }
266  }
267  // If we are here, we have successfully applied the shape.
268  return Status::OK();
269}
270
271const int64 TensorSlice::kFullExtent = -1;
272
273}  // namespace tensorflow
274