1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/kernels/test_util.h"
16
17#include "tensorflow/contrib/lite/version.h"
18#include "tensorflow/core/platform/logging.h"
19
20namespace tflite {
21
22using ::testing::FloatNear;
23using ::testing::Matcher;
24
25namespace {
26template <typename T>
27std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
28  // These are required by many quantized operations.
29  CHECK_LE(f_min, 0);
30  CHECK_GE(f_max, 0);
31  T q_min = std::numeric_limits<T>::min();
32  T q_max = std::numeric_limits<T>::max();
33  float range = q_max - q_min;
34  float scale = (f_max - f_min) / range;
35  int32_t zero_point = std::min(
36      q_max,
37      std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
38  return {scale, zero_point};
39}
40}  // namespace
41
42std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
43                                           float max_abs_error) {
44  std::vector<Matcher<float>> matchers;
45  matchers.reserve(values.size());
46  for (const float& v : values) {
47    matchers.emplace_back(FloatNear(v, max_abs_error));
48  }
49  return matchers;
50}
51
52int SingleOpModel::AddTensor(TensorData t, std::initializer_list<int> data) {
53  int id = tensors_.size();
54
55  // This is slightly different depending on whether we are adding a
56  // quantized or a regular tensor.
57  bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
58
59  flatbuffers::Offset<QuantizationParameters> q_params = 0;
60
61  if (is_quantized) {
62    if (t.min != 0 || t.max != 0) {
63      if (t.type == TensorType_UINT8) {
64        std::tie(t.scale, t.zero_point) =
65            QuantizationParams<uint8_t>(t.min, t.max);
66      } else if (t.type == TensorType_INT32) {
67        std::tie(t.scale, t.zero_point) =
68            QuantizationParams<int32_t>(t.min, t.max);
69      } else {
70        LOG(FATAL) << "No support for the requested quantized type";
71      }
72      t.min = 0;
73      t.max = 0;
74    }
75
76    q_params = CreateQuantizationParameters(
77        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>({t.scale}),
78        builder_.CreateVector<int64_t>({t.zero_point}));
79  }
80
81  int buffer_id = 0;
82  if (data.size()) {
83    // Initialize buffers list with empty buffer to allow for non-const tensors.
84    if (buffers_.empty()) {
85      buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
86    }
87
88    // Add data as a Buffer to buffers list.
89    buffer_id = buffers_.size();
90    auto data_buffer =
91        builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()),
92                              sizeof(int) * data.size());
93    buffers_.push_back(CreateBuffer(builder_, data_buffer));
94  }
95
96  tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
97                                  t.type, /*buffer=*/buffer_id,
98                                  /*name=*/0, q_params));
99
100  tensor_data_[id] = t;
101
102  return id;
103}
104
105int SingleOpModel::AddInput(const TensorData& t) {
106  int id = AddTensor(t, {});
107  inputs_.push_back(id);
108  return id;
109}
110
111int SingleOpModel::AddConstInput(TensorType type,
112                                 std::initializer_list<int> data,
113                                 std::initializer_list<int> shape) {
114  int id = AddTensor(TensorData{type, shape}, data);
115  inputs_.push_back(id);
116  return id;
117}
118
119int SingleOpModel::AddNullInput() {
120  int id = kOptionalTensor;
121  inputs_.push_back(id);
122  return id;
123}
124
125int SingleOpModel::AddOutput(const TensorData& t) {
126  int id = AddTensor(t, {});
127  outputs_.push_back(id);
128  return id;
129}
130
131void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
132                                 BuiltinOptions builtin_options_type,
133                                 flatbuffers::Offset<void> builtin_options) {
134  opcodes_.push_back(CreateOperatorCode(builder_, type, 0));
135  operators_.push_back(CreateOperator(
136      builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
137      builder_.CreateVector<int32_t>(outputs_), builtin_options_type,
138      builtin_options,
139      /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS));
140}
141
142void SingleOpModel::SetCustomOp(
143    const string& name, const std::vector<uint8_t>& custom_option,
144    const std::function<TfLiteRegistration*()>& registeration) {
145  custom_registrations_[name] = registeration;
146  opcodes_.push_back(
147      CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data()));
148  operators_.push_back(CreateOperator(
149      builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
150      builder_.CreateVector<int32_t>(outputs_), BuiltinOptions_NONE, 0,
151      builder_.CreateVector<uint8_t>(custom_option),
152      CustomOptionsFormat_FLEXBUFFERS));
153}
154
155void SingleOpModel::BuildInterpreter(
156    std::vector<std::vector<int>> input_shapes) {
157  auto opcodes = builder_.CreateVector(opcodes_);
158  auto operators = builder_.CreateVector(operators_);
159  auto tensors = builder_.CreateVector(tensors_);
160  auto inputs = builder_.CreateVector<int32_t>(inputs_);
161  auto outputs = builder_.CreateVector<int32_t>(outputs_);
162  // Create a single subgraph
163  std::vector<flatbuffers::Offset<SubGraph>> subgraphs;
164  auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators);
165  subgraphs.push_back(subgraph);
166  auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs);
167
168  auto buffers = builder_.CreateVector(buffers_);
169  auto description = builder_.CreateString("programmatic model");
170  builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
171                              subgraphs_flatbuffer, description, buffers));
172
173  auto* model = GetModel(builder_.GetBufferPointer());
174
175  if (!resolver_) {
176    auto resolver = new ops::builtin::BuiltinOpResolver();
177    for (const auto& reg : custom_registrations_) {
178      resolver->AddCustom(reg.first.data(), reg.second());
179    }
180    resolver_ = std::unique_ptr<OpResolver>(resolver);
181  }
182  InterpreterBuilder(model, *resolver_)(&interpreter_);
183
184  CHECK(interpreter_ != nullptr);
185
186  int i = 0;
187  for (const auto& shape : input_shapes) {
188    int input_idx = interpreter_->inputs()[i++];
189    if (input_idx == kOptionalTensor) continue;
190    if (shape.empty()) continue;
191    CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
192  }
193  CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
194      << "Cannot allocate tensors";
195}
196
197void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
198
199int32_t SingleOpModel::GetTensorSize(int index) const {
200  TfLiteTensor* t = interpreter_->tensor(index);
201  CHECK(t);
202  int total_size = 1;
203  for (int i = 0; i < t->dims->size; ++i) {
204    total_size *= t->dims->data[i];
205  }
206  return total_size;
207}
208
209template <>
210std::vector<string> SingleOpModel::ExtractVector(int index) {
211  TfLiteTensor* tensor_ptr = interpreter_->tensor(index);
212  CHECK(tensor_ptr != nullptr);
213  const int num_strings = GetStringCount(tensor_ptr);
214  std::vector<string> result;
215  result.reserve(num_strings);
216  for (int i = 0; i < num_strings; ++i) {
217    const auto str = GetString(tensor_ptr, i);
218    result.emplace_back(str.str, str.len);
219  }
220  return result;
221}
222}  // namespace tflite
223