1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/types.h"

#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
16
17namespace toco {
18
19namespace tflite {
20
21namespace {
22template <ArrayDataType T>
23DataBuffer::FlatBufferOffset CopyBuffer(
24    const Array& array, flatbuffers::FlatBufferBuilder* builder) {
25  using NativeT = ::toco::DataType<T>;
26  const auto& src_data = array.GetBuffer<T>().data;
27  const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
28  auto size = src_data.size() * sizeof(NativeT);
29  return builder->CreateVector(dst_data, size);
30}
31
32template <ArrayDataType T>
33void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
34  using NativeT = ::toco::DataType<T>;
35  auto* src_buffer = buffer.data();
36  const NativeT* src_data =
37      reinterpret_cast<const NativeT*>(src_buffer->data());
38  int num_items = src_buffer->size() / sizeof(NativeT);
39
40  std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
41  for (int i = 0; i < num_items; ++i) {
42    dst_data->push_back(*src_data);
43    ++src_data;
44  }
45}
46}  // namespace
47
48::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
49  switch (array_data_type) {
50    case ArrayDataType::kFloat:
51      return ::tflite::TensorType_FLOAT32;
52    case ArrayDataType::kInt32:
53      return ::tflite::TensorType_INT32;
54    case ArrayDataType::kInt64:
55      return ::tflite::TensorType_INT64;
56    case ArrayDataType::kUint8:
57      return ::tflite::TensorType_UINT8;
58    case ArrayDataType::kString:
59      return ::tflite::TensorType_STRING;
60    default:
61      // FLOAT32 is filled for unknown data types.
62      // TODO(ycling): Implement type inference in TF Lite interpreter.
63      return ::tflite::TensorType_FLOAT32;
64  }
65}
66
67ArrayDataType DataType::Deserialize(int tensor_type) {
68  switch (::tflite::TensorType(tensor_type)) {
69    case ::tflite::TensorType_FLOAT32:
70      return ArrayDataType::kFloat;
71    case ::tflite::TensorType_INT32:
72      return ArrayDataType::kInt32;
73    case ::tflite::TensorType_INT64:
74      return ArrayDataType::kInt64;
75    case ::tflite::TensorType_STRING:
76      return ArrayDataType::kString;
77    case ::tflite::TensorType_UINT8:
78      return ArrayDataType::kUint8;
79    default:
80      LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
81  }
82}
83
84flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
85    const Array& array, flatbuffers::FlatBufferBuilder* builder) {
86  if (!array.buffer) return 0;  // an empty buffer, usually an output.
87
88  switch (array.data_type) {
89    case ArrayDataType::kFloat:
90      return CopyBuffer<ArrayDataType::kFloat>(array, builder);
91    case ArrayDataType::kInt32:
92      return CopyBuffer<ArrayDataType::kInt32>(array, builder);
93    case ArrayDataType::kString:
94      return CopyBuffer<ArrayDataType::kString>(array, builder);
95    case ArrayDataType::kUint8:
96      return CopyBuffer<ArrayDataType::kUint8>(array, builder);
97    default:
98      LOG(FATAL) << "Unhandled array data type.";
99  }
100}
101
102void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
103                             const ::tflite::Buffer& buffer, Array* array) {
104  if (tensor.buffer() == 0) return;      // an empty buffer, usually an output.
105  if (buffer.data() == nullptr) return;  // a non-defined buffer.
106
107  switch (tensor.type()) {
108    case ::tflite::TensorType_FLOAT32:
109      return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
110    case ::tflite::TensorType_INT32:
111      return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
112    case ::tflite::TensorType_INT64:
113      return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
114    case ::tflite::TensorType_STRING:
115      return CopyBuffer<ArrayDataType::kString>(buffer, array);
116    case ::tflite::TensorType_UINT8:
117      return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
118    default:
119      LOG(FATAL) << "Unhandled tensor type.";
120  }
121}
122
123::tflite::Padding Padding::Serialize(PaddingType padding_type) {
124  switch (padding_type) {
125    case PaddingType::kSame:
126      return ::tflite::Padding_SAME;
127    case PaddingType::kValid:
128      return ::tflite::Padding_VALID;
129    default:
130      LOG(FATAL) << "Unhandled padding type.";
131  }
132}
133
134PaddingType Padding::Deserialize(int padding) {
135  switch (::tflite::Padding(padding)) {
136    case ::tflite::Padding_SAME:
137      return PaddingType::kSame;
138    case ::tflite::Padding_VALID:
139      return PaddingType::kValid;
140    default:
141      LOG(FATAL) << "Unhandled padding.";
142  }
143}
144
145::tflite::ActivationFunctionType ActivationFunction::Serialize(
146    FusedActivationFunctionType faf_type) {
147  switch (faf_type) {
148    case FusedActivationFunctionType::kNone:
149      return ::tflite::ActivationFunctionType_NONE;
150    case FusedActivationFunctionType::kRelu:
151      return ::tflite::ActivationFunctionType_RELU;
152    case FusedActivationFunctionType::kRelu6:
153      return ::tflite::ActivationFunctionType_RELU6;
154    case FusedActivationFunctionType::kRelu1:
155      return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
156    default:
157      LOG(FATAL) << "Unhandled fused activation function type.";
158  }
159}
160
161FusedActivationFunctionType ActivationFunction::Deserialize(
162    int activation_function) {
163  switch (::tflite::ActivationFunctionType(activation_function)) {
164    case ::tflite::ActivationFunctionType_NONE:
165      return FusedActivationFunctionType::kNone;
166    case ::tflite::ActivationFunctionType_RELU:
167      return FusedActivationFunctionType::kRelu;
168    case ::tflite::ActivationFunctionType_RELU6:
169      return FusedActivationFunctionType::kRelu6;
170    case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
171      return FusedActivationFunctionType::kRelu1;
172    default:
173      LOG(FATAL) << "Unhandled fused activation function type.";
174  }
175}
176
177}  // namespace tflite
178
179}  // namespace toco
180