1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
18#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
19
20#include <algorithm>
21#include <vector>
22
23namespace libtextclassifier2 {
namespace internal {
// Computes the number of elements in a tensor of given shape.
// Used by the TensorView constructor to precompute size().
// NOTE(review): the definition is out of line and not visible here. Presumably
// it returns the product of all dimensions, which would make an empty shape
// yield 1 rather than 0 — confirm against the definition. The result is an
// int, so if the product is accumulated in int, very large shapes could
// overflow — TODO confirm.
int NumberOfElements(const std::vector<int>& shape);
}  // namespace internal
28
29// View of a tensor of given type.
30// NOTE: Does not own the underlying memory, so the contract about its validity
31// needs to be specified on the interface that returns it.
32template <typename T>
33class TensorView {
34 public:
35  TensorView(const T* data, const std::vector<int>& shape)
36      : data_(data), shape_(shape), size_(internal::NumberOfElements(shape)) {}
37
38  static TensorView Invalid() {
39    static std::vector<int>& invalid_shape =
40        *[]() { return new std::vector<int>(0); }();
41    return TensorView(nullptr, invalid_shape);
42  }
43
44  bool is_valid() const { return data_ != nullptr; }
45
46  const std::vector<int>& shape() const { return shape_; }
47
48  int dim(int i) const { return shape_[i]; }
49
50  int dims() const { return shape_.size(); }
51
52  const T* data() const { return data_; }
53
54  int size() const { return size_; }
55
56  bool copy_to(T* dest, int dest_size) const {
57    if (dest_size < size_) {
58      return false;
59    }
60    std::copy(data_, data_ + size_, dest);
61    return true;
62  }
63
64 private:
65  const T* data_ = nullptr;
66  const std::vector<int> shape_;
67  const int size_;
68};
69
70}  // namespace libtextclassifier2
71
72#endif  // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
73