/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_

#include <algorithm>
#include <utility>
#include <vector>

namespace libtextclassifier2 {
namespace internal {
// Computes the number of elements in a tensor of given shape.
// Only declared here; the definition lives out of line.
int NumberOfElements(const std::vector<int>& shape);
}  // namespace internal

// View of a tensor of given type.
// NOTE: Does not own the underlying memory, so the contract about its validity
// needs to be specified on the interface that returns it.
//
// A view is considered valid iff its data pointer is non-null (see Invalid()).
// Views are copyable and assignable; copies refer to the same underlying data.
template <typename T>
class TensorView {
 public:
  // Wraps `data` (not owned) with the given `shape`. `shape` is taken by value
  // so rvalue arguments are moved in rather than copied.
  TensorView(const T* data, std::vector<int> shape)
      : data_(data),
        shape_(std::move(shape)),
        // Computed from the member, not the (possibly moved-from) parameter.
        // shape_ is declared before size_, so it is already initialized here.
        size_(internal::NumberOfElements(shape_)) {}

  // Returns a view with a null data pointer and an empty shape.
  static TensorView Invalid() {
    return TensorView(nullptr, std::vector<int>());
  }

  bool is_valid() const { return data_ != nullptr; }

  const std::vector<int>& shape() const { return shape_; }

  // Size of dimension `i`. Precondition: 0 <= i < dims(); not range-checked.
  int dim(int i) const { return shape_[i]; }

  // Number of dimensions (rank) of the shape.
  int dims() const { return static_cast<int>(shape_.size()); }

  const T* data() const { return data_; }

  // Element count, as computed by internal::NumberOfElements(shape()).
  int size() const { return size_; }

  // Copies the size() elements of the view into `dest`, whose capacity in
  // elements is `dest_size`.
  // Returns false, without copying anything, if `dest_size` < size(), or if
  // the view has a null data pointer but a positive size (copying would read
  // through nullptr). Returns true otherwise.
  bool copy_to(T* dest, int dest_size) const {
    if (dest_size < size_) {
      return false;
    }
    if (data_ == nullptr && size_ > 0) {
      return false;
    }
    std::copy(data_, data_ + size_, dest);
    return true;
  }

 private:
  // Non-owning; nullptr marks an invalid view.
  const T* data_ = nullptr;
  // Deliberately not const: const members would disable assignment and force
  // copies instead of moves of the shape vector.
  std::vector<int> shape_;
  // Cached element count; must stay declared after shape_ (see constructor).
  int size_;
};

}  // namespace libtextclassifier2

#endif  // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_