121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka/*
221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * Copyright (C) 2017 The Android Open Source Project
321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka *
421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * Licensed under the Apache License, Version 2.0 (the "License");
521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * you may not use this file except in compliance with the License.
621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * You may obtain a copy of the License at
721d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka *
821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka *      http://www.apache.org/licenses/LICENSE-2.0
921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka *
1021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * Unless required by applicable law or agreed to in writing, software
1121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * distributed under the License is distributed on an "AS IS" BASIS,
1221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * See the License for the specific language governing permissions and
1421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka * limitations under the License.
1521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka */
1621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
17b23e2125be90bbf6124e9cd5684fc93026c5ec4dLukas Zilka#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
18b23e2125be90bbf6124e9cd5684fc93026c5ec4dLukas Zilka#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
1921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
2021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka#include <algorithm>
2121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka#include <vector>
2221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
2321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilkanamespace libtextclassifier2 {
namespace internal {
// Returns the total element count of a tensor whose per-dimension sizes are
// given by `shape`. Only declared here; the definition lives out of line.
auto NumberOfElements(const std::vector<int>& shape) -> int;
}  // namespace internal
2821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
2921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka// View of a tensor of given type.
3021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka// NOTE: Does not own the underlying memory, so the contract about its validity
3121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka// needs to be specified on the interface that returns it.
3221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilkatemplate <typename T>
3321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilkaclass TensorView {
3421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka public:
3521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  TensorView(const T* data, const std::vector<int>& shape)
3621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka      : data_(data), shape_(shape), size_(internal::NumberOfElements(shape)) {}
3721d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
3821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  static TensorView Invalid() {
3921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    static std::vector<int>& invalid_shape =
4021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka        *[]() { return new std::vector<int>(0); }();
4121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    return TensorView(nullptr, invalid_shape);
4221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  }
4321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
4421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  bool is_valid() const { return data_ != nullptr; }
4521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
4621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  const std::vector<int>& shape() const { return shape_; }
4721d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
4821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  int dim(int i) const { return shape_[i]; }
4921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
5021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  int dims() const { return shape_.size(); }
5121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
5221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  const T* data() const { return data_; }
5321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
5421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  int size() const { return size_; }
5521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
5621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  bool copy_to(T* dest, int dest_size) const {
5721d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    if (dest_size < size_) {
5821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka      return false;
5921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    }
6021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    std::copy(data_, data_ + size_, dest);
6121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka    return true;
6221d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  }
6321d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
6421d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka private:
6521d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  const T* data_ = nullptr;
6621d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  const std::vector<int> shape_;
6721d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka  const int size_;
6821d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka};
6921d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
7021d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka}  // namespace libtextclassifier2
7121d8c98fb12bc83dd0e9f5cb8fa9197ef325e074Lukas Zilka
72b23e2125be90bbf6124e9cd5684fc93026c5ec4dLukas Zilka#endif  // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
73