1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <tuple>
16#include <vector>
17
18#include <gmock/gmock.h>
19#include <gtest/gtest.h>
20#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
21#include "tensorflow/contrib/lite/toco/model.h"
22#include "tensorflow/contrib/lite/toco/tooling_util.h"
23
24namespace toco {
25
26namespace {
27
28// A gmock matcher that check that elements of a float vector match to a given
29// tolerance.
30std::vector<testing::Matcher<float>> ArrayFloatNear(
31    const std::vector<float>& values, float max_abs_error = 1e-5) {
32  std::vector<testing::Matcher<float>> matchers;
33  matchers.reserve(values.size());
34  for (const float& v : values) {
35    matchers.emplace_back(testing::FloatNear(v, max_abs_error));
36  }
37  return matchers;
38}
39}  // namespace
40
41class CopyArrayDataTest : public ::testing::Test {
42 public:
43  CopyArrayDataTest() {}
44
45  void PrepareBuffers(Model* model, std::initializer_list<float> src_data,
46                      int src_dim_1, int src_dim_2,
47                      std::initializer_list<float> dst_data, int dst_dim_1,
48                      int dst_dim_2) {
49    string src_array = "src_array";
50    src_buffer_ = CreateFloatArrayBuffer(
51        model, &src_array,
52        src_dim_2 == 1 ? Shape({src_dim_1}) : Shape({src_dim_1, src_dim_2}));
53    PopulateBuffer(src_buffer_, src_data);
54    string dst_array = "dst_array";
55    dst_buffer_ = CreateFloatArrayBuffer(
56        model, &dst_array,
57        dst_dim_2 == 1 ? Shape({dst_dim_1}) : Shape({dst_dim_1, dst_dim_2}));
58    PopulateBuffer(dst_buffer_, dst_data);
59  }
60
61  Buffer<ArrayDataType::kFloat>* GetSrcBuffer() { return src_buffer_; }
62  Buffer<ArrayDataType::kFloat>* GetDstBuffer() { return dst_buffer_; }
63
64  void PopulateBuffer(Buffer<ArrayDataType::kFloat>* buffer,
65                      const std::vector<float>& init_data) {
66    for (int i = 0; i < init_data.size(); i++) {
67      buffer->data[i] = init_data[i];
68    }
69  }
70  void UpdateBuffer(Buffer<ArrayDataType::kFloat>* buffer,
71                    std::initializer_list<float> data) {
72    buffer->data.resize(data.size());
73    PopulateBuffer(buffer, data);
74  }
75
76 private:
77  Buffer<ArrayDataType::kFloat>* src_buffer_;
78  Buffer<ArrayDataType::kFloat>* dst_buffer_;
79};
80
81// Copy from 1 big 2D array to 8 smaller ones.
82TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes2D) {
83  // Init src_buffer, dst_buffer.
84  Model model;
85  std::initializer_list<float> large_tf_weight_data = {
86      -0.320407, -0.108683, 0.406358,  -0.410811, -0.285786, -0.15769,
87      -0.194201, 0.170866,  0.084135,  0.201878,  0.21519,   -0.284458,
88      0.495906,  -0.073818, 0.045578,  0.149816,  -0.447073, -0.453578,
89      0.116766,  0.21808,   0.047326,  -0.001985, 0.402193,  0.315517,
90      0.38258,   0.43599,   0.11986,   0.465195,  0.33548,   -0.118789,
91      -0.414159, 0.049269,  0.156108,  0.093459,  -0.129103, -0.086274,
92      0.186188,  -0.324923, 0.4117,    -0.344439, 0.240465,  -0.343331,
93      -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007,
94      0.364817,  0.459106,  -0.171447, -0.006542, 0.204032,  -0.375317,
95      -0.041911, 0.051664,  0.320483,  0.155899,  0.156555,  -0.249823,
96      -0.353107, 0.031563,  -0.340771, -0.052532, 0.134631,  -0.257957,
97      -0.50141,  0.486939,  -0.43853,  0.268426,  -0.08754,  -0.109447,
98      -0.502462, -0.028055, -0.121838, -0.046016, 0.105309,  -0.070774,
99      0.495683,  -0.475088, 0.048654,  -0.38582,  0.411018,  -0.315606,
100      0.349628,  0.21698,   0.258989,  -0.097902, 0.331218,  0.034602,
101      0.418069,  -0.089025, -0.417513, 0.07609,   0.393821,  0.404733,
102      -0.055418, -0.43903,  -0.447049, 0.013125,  0.278503,  0.459869,
103      0.143755,  -0.177335, -0.162247, -0.432371, 0.153714,  -0.047403,
104      -0.446775, -0.418363, 0.019743,  0.042025};
105  std::initializer_list<float> tflite_lstm_input_weight = {0, 0, 0, 0, 0, 0,
106                                                           0, 0, 0, 0, 0, 0};
107  PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16,
108                 /*src_dim_2=*/7, tflite_lstm_input_weight,
109                 /*dst_dim_1=*/4, /*dst_dim_2=*/3);
110
111  // Copy src starts at (0,0), size (4,3).
112  CopyArrayData(*(GetSrcBuffer()),
113                /*src_stride=*/7, /*src_start_idx1=*/0,
114                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
115                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
116                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
117  std::vector<float> expected = {-0.320407, -0.108683, 0.406358, 0.170866,
118                                 0.084135,  0.201878,  0.045578, 0.149816,
119                                 -0.447073, -0.001985, 0.402193, 0.315517};
120  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
121
122  // Copy src starts at (4,0), size (4,3).
123  CopyArrayData(*(GetSrcBuffer()),
124                /*src_stride=*/7, /*src_start_idx1=*/4,
125                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
126                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
127                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
128  expected = {0.33548,   -0.118789, -0.414159, -0.086274, 0.186188,  -0.324923,
129              -0.463082, -0.231706, -0.487465, 0.459106,  -0.171447, -0.006542};
130  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
131
132  // Copy src starts at (8,0), size (4,3).
133  CopyArrayData(*(GetSrcBuffer()),
134                /*src_stride=*/7, /*src_start_idx1=*/8,
135                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
136                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
137                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
138  expected = {0.320483, 0.155899,  0.156555,  -0.052532, 0.134631, -0.257957,
139              -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088};
140  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
141
142  // Copy src starts at (12,0), size (4,3).
143  CopyArrayData(*(GetSrcBuffer()),
144                /*src_stride=*/7, /*src_start_idx1=*/12,
145                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
146                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
147                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
148  expected = {0.349628,  0.21698,  0.258989, -0.089025, -0.417513, 0.07609,
149              -0.447049, 0.013125, 0.278503, -0.432371, 0.153714,  -0.047403};
150  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
151
152  // New dst_buffer with size 16.
153  std::initializer_list<float> tflite_lstm_recurrent_weight = {
154      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
155  PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16,
156                 /*src_dim_2=*/7, tflite_lstm_recurrent_weight,
157                 /*dst_dim_1=*/4, /*dst_dim_2=*/4);
158
159  // Copy src starts at (0,3), size (4,4).
160  CopyArrayData(*(GetSrcBuffer()),
161                /*src_stride=*/7, /*src_start_idx1=*/0,
162                /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
163                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
164                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
165  expected = {-0.410811, -0.285786, -0.15769,  -0.194201, 0.21519, -0.284458,
166              0.495906,  -0.073818, -0.453578, 0.116766,  0.21808, 0.047326,
167              0.38258,   0.43599,   0.11986,   0.465195};
168  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
169
170  // Copy src starts at (4,3), size (4,4).
171  CopyArrayData(*(GetSrcBuffer()),
172                /*src_stride=*/7, /*src_start_idx1=*/4,
173                /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
174                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
175                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
176  expected = {0.049269, 0.156108,  0.093459,  -0.129103, 0.4117,    -0.344439,
177              0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817,
178              0.204032, -0.375317, -0.041911, 0.051664};
179  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
180
181  // Copy src starts at (8,3), size (4,4).
182  CopyArrayData(*(GetSrcBuffer()),
183                /*src_stride=*/7, /*src_start_idx1=*/8,
184                /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
185                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
186                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
187  expected = {-0.249823, -0.353107, 0.031563,  -0.340771, -0.50141,  0.486939,
188              -0.43853,  0.268426,  -0.028055, -0.121838, -0.046016, 0.105309,
189              0.048654,  -0.38582,  0.411018,  -0.315606};
190  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
191
192  // Copy src starts at (12,3), size (4,4).
193  CopyArrayData(*(GetSrcBuffer()),
194                /*src_stride=*/7, /*src_start_idx1=*/12,
195                /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
196                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
197                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
198  expected = {-0.097902, 0.331218,  0.034602, 0.418069, 0.393821,  0.404733,
199              -0.055418, -0.43903,  0.459869, 0.143755, -0.177335, -0.162247,
200              -0.446775, -0.418363, 0.019743, 0.042025};
201  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
202}
203
204// Copy from 1 big 1D array to 4 small ones.
205TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes1D) {
206  // Init src_buffer, dst_buffer.
207  Model model;
208  std::initializer_list<float> large_tf_bias_data = {
209      0.980304, 0.419808, 0.080278, 0.728548, 0.581674, 0.672433,
210      0.434190, 0.844357, 0.229587, 0.785629, 0.022065, 0.753082,
211      0.422080, 0.539481, 0.878386, 0.168965};
212  std::initializer_list<float> tflite_lstm_i_bias = {0, 0, 0, 0};
213  PrepareBuffers(&model, large_tf_bias_data, /*src_dim_1=*/16,
214                 /*src_dim_2=*/1, tflite_lstm_i_bias,
215                 /*dst_dim_1=*/4, /*dst_dim_2=*/1);
216
217  // Copy starts at (0,), size (4,).
218  CopyArrayData(*(GetSrcBuffer()),
219                /*src_stride=*/1, /*src_start_idx1=*/0,
220                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
221                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
222                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
223  std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548};
224  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
225
226  // Copy starts at (4,), size (4,).
227  CopyArrayData(*(GetSrcBuffer()),
228                /*src_stride=*/1, /*src_start_idx1=*/4,
229                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
230                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
231                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
232  expected = {0.581674, 0.672433, 0.434190, 0.844357};
233  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
234
235  // Copy starts at (8,), size (4,).
236  CopyArrayData(*(GetSrcBuffer()),
237                /*src_stride=*/1, /*src_start_idx1=*/8,
238                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
239                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
240                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
241  expected = {0.229587, 0.785629, 0.022065, 0.753082};
242  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
243
244  // Copy starts at (12,), size (4,).
245  CopyArrayData(*(GetSrcBuffer()),
246                /*src_stride=*/1, /*src_start_idx1=*/12,
247                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
248                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
249                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
250  expected = {0.422080, 0.539481, 0.878386, 0.168965};
251  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
252}
253
254// Copy from 8 small 2D arrayes to 1 big one.
255TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray2D) {
256  // Init src_buffer, dst_buffer.
257  Model model;
258  std::initializer_list<float> large_tf_weights_data = {
259      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
260      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
261      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
262      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
263      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
264
265  // Copy dst starts (0, 0), size (4, 3).
266  std::initializer_list<float> tflite_lstm_i2i_weight = {
267      -0.320407, -0.108683, 0.406358,  0.170866,  0.084135, 0.201878,
268      0.045578,  0.149816,  -0.447073, -0.001985, 0.402193, 0.315517};
269  PrepareBuffers(&model, tflite_lstm_i2i_weight, /*src_dim_1=*/4,
270                 /*src_dim_2=*/3, large_tf_weights_data,
271                 /*dst_dim_1=*/16, /*dst_dim_2=*/7);
272  CopyArrayData(*(GetSrcBuffer()),
273                /*src_stride=*/3, /*src_start_idx1=*/0,
274                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
275                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
276                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
277
278  // Copy dst starts (4, 0), size (4, 3).
279  std::initializer_list<float> tflite_lstm_i2c_weight = {
280      0.33548,   -0.118789, -0.414159, -0.086274, 0.186188,  -0.324923,
281      -0.463082, -0.231706, -0.487465, 0.459106,  -0.171447, -0.006542};
282  PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2c_weight);
283  CopyArrayData(*(GetSrcBuffer()),
284                /*src_stride=*/3, /*src_start_idx1=*/0,
285                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
286                /*dst_start_idx1=*/4, /*dst_start_idx2=*/0,
287                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
288
289  // Copy dst starts (8, 0), size (4, 3).
290  std::initializer_list<float> tflite_lstm_i2f_weight = {
291      0.320483, 0.155899,  0.156555,  -0.052532, 0.134631, -0.257957,
292      -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088};
293  PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2f_weight);
294  CopyArrayData(*(GetSrcBuffer()),
295                /*src_stride=*/3, /*src_start_idx1=*/0,
296                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
297                /*dst_start_idx1=*/8, /*dst_start_idx2=*/0,
298                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
299
300  // Copy dst starts (12, 0), size (4, 3).
301  std::initializer_list<float> tflite_lstm_i2o_weight = {
302      0.349628,  0.21698,  0.258989, -0.089025, -0.417513, 0.07609,
303      -0.447049, 0.013125, 0.278503, -0.432371, 0.153714,  -0.047403};
304  PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2o_weight);
305  CopyArrayData(*(GetSrcBuffer()),
306                /*src_stride=*/3, /*src_start_idx1=*/0,
307                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
308                /*dst_start_idx1=*/12, /*dst_start_idx2=*/0,
309                /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
310
311  // Copy dst starts (0, 3), size (4, 4).
312  std::initializer_list<float> tflite_lstm_i2r_weight = {
313      -0.410811, -0.285786, -0.15769,  -0.194201, 0.21519, -0.284458,
314      0.495906,  -0.073818, -0.453578, 0.116766,  0.21808, 0.047326,
315      0.38258,   0.43599,   0.11986,   0.465195};
316  UpdateBuffer(GetSrcBuffer(), tflite_lstm_i2r_weight);
317  CopyArrayData(*(GetSrcBuffer()),
318                /*src_stride=*/4, /*src_start_idx1=*/0,
319                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
320                /*dst_start_idx1=*/0, /*dst_start_idx2=*/3,
321                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
322
323  // Copy dst starts (4, 3), size (4, 4).
324  std::initializer_list<float> tflite_lstm_c2r_weight = {
325      0.049269, 0.156108,  0.093459,  -0.129103, 0.4117,    -0.344439,
326      0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817,
327      0.204032, -0.375317, -0.041911, 0.051664};
328  PopulateBuffer(GetSrcBuffer(), tflite_lstm_c2r_weight);
329  CopyArrayData(*(GetSrcBuffer()),
330                /*src_stride=*/4, /*src_start_idx1=*/0,
331                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
332                /*dst_start_idx1=*/4, /*dst_start_idx2=*/3,
333                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
334
335  // Copy dst starts (8, 3), size (4, 4).
336  std::initializer_list<float> tflite_lstm_f2r_weight = {
337      -0.249823, -0.353107, 0.031563,  -0.340771, -0.50141,  0.486939,
338      -0.43853,  0.268426,  -0.028055, -0.121838, -0.046016, 0.105309,
339      0.048654,  -0.38582,  0.411018,  -0.315606};
340  PopulateBuffer(GetSrcBuffer(), tflite_lstm_f2r_weight);
341  CopyArrayData(*(GetSrcBuffer()),
342                /*src_stride=*/4, /*src_start_idx1=*/0,
343                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
344                /*dst_start_idx1=*/8, /*dst_start_idx2=*/3,
345                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
346
347  // Copy dst starts (12, 3), size (4, 4).
348  std::initializer_list<float> tflite_lstm_o2r_weight = {
349      -0.097902, 0.331218,  0.034602, 0.418069, 0.393821,  0.404733,
350      -0.055418, -0.43903,  0.459869, 0.143755, -0.177335, -0.162247,
351      -0.446775, -0.418363, 0.019743, 0.042025};
352  PopulateBuffer(GetSrcBuffer(), tflite_lstm_o2r_weight);
353  CopyArrayData(*(GetSrcBuffer()),
354                /*src_stride=*/4, /*src_start_idx1=*/0,
355                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
356                /*dst_start_idx1=*/12, /*dst_start_idx2=*/3,
357                /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
358
359  std::vector<float> expected = {
360      -0.320407, -0.108683, 0.406358,  -0.410811, -0.285786, -0.15769,
361      -0.194201, 0.170866,  0.084135,  0.201878,  0.21519,   -0.284458,
362      0.495906,  -0.073818, 0.045578,  0.149816,  -0.447073, -0.453578,
363      0.116766,  0.21808,   0.047326,  -0.001985, 0.402193,  0.315517,
364      0.38258,   0.43599,   0.11986,   0.465195,  0.33548,   -0.118789,
365      -0.414159, 0.049269,  0.156108,  0.093459,  -0.129103, -0.086274,
366      0.186188,  -0.324923, 0.4117,    -0.344439, 0.240465,  -0.343331,
367      -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007,
368      0.364817,  0.459106,  -0.171447, -0.006542, 0.204032,  -0.375317,
369      -0.041911, 0.051664,  0.320483,  0.155899,  0.156555,  -0.249823,
370      -0.353107, 0.031563,  -0.340771, -0.052532, 0.134631,  -0.257957,
371      -0.50141,  0.486939,  -0.43853,  0.268426,  -0.08754,  -0.109447,
372      -0.502462, -0.028055, -0.121838, -0.046016, 0.105309,  -0.070774,
373      0.495683,  -0.475088, 0.048654,  -0.38582,  0.411018,  -0.315606,
374      0.349628,  0.21698,   0.258989,  -0.097902, 0.331218,  0.034602,
375      0.418069,  -0.089025, -0.417513, 0.07609,   0.393821,  0.404733,
376      -0.055418, -0.43903,  -0.447049, 0.013125,  0.278503,  0.459869,
377      0.143755,  -0.177335, -0.162247, -0.432371, 0.153714,  -0.047403,
378      -0.446775, -0.418363, 0.019743,  0.042025};
379
380  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
381}
382
383// Copy from 4 small 1D arrayes to 1 big one.
384TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray1D) {
385  // Init src_buffer, dst_buffer.
386  Model model;
387  std::initializer_list<float> large_tf_bias_data = {0, 0, 0, 0, 0, 0, 0, 0,
388                                                     0, 0, 0, 0, 0, 0, 0, 0};
389
390  std::initializer_list<float> tflite_lstm_i_bias = {0.980304, 0.419808,
391                                                     0.080278, 0.728548};
392
393  PrepareBuffers(&model, tflite_lstm_i_bias, /*src_dim_1=*/4,
394                 /*src_dim_2=*/1, large_tf_bias_data,
395                 /*dst_dim_1=*/16, /*dst_dim_2=*/1);
396
397  // Copy starts at (0,), size (4,).
398  CopyArrayData(*(GetSrcBuffer()),
399                /*src_stride=*/1, /*src_start_idx1=*/0,
400                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
401                /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
402                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
403
404  // Copy starts at (4,), size (4,).
405  std::initializer_list<float> tflite_lstm_cell_bias = {0.581674, 0.672433,
406                                                        0.434190, 0.844357};
407  PopulateBuffer(GetSrcBuffer(), tflite_lstm_cell_bias);
408  CopyArrayData(*(GetSrcBuffer()),
409                /*src_stride=*/1, /*src_start_idx1=*/0,
410                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
411                /*dst_start_idx1=*/4, /*dst_start_idx2=*/0,
412                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
413
414  // Copy starts at (8,0), size (4,).
415  std::initializer_list<float> tflite_lstm_forget_bias = {0.229587, 0.785629,
416                                                          0.022065, 0.753082};
417  PopulateBuffer(GetSrcBuffer(), tflite_lstm_forget_bias);
418  CopyArrayData(*(GetSrcBuffer()),
419                /*src_stride=*/1, /*src_start_idx1=*/0,
420                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
421                /*dst_start_idx1=*/8, /*dst_start_idx2=*/0,
422                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
423
424  // Copy starts at (12,), size (4,).
425  std::initializer_list<float> tflite_lstm_output_bias = {0.422080, 0.539481,
426                                                          0.878386, 0.168965};
427  PopulateBuffer(GetSrcBuffer(), tflite_lstm_output_bias);
428  CopyArrayData(*(GetSrcBuffer()),
429                /*src_stride=*/1, /*src_start_idx1=*/0,
430                /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
431                /*dst_start_idx1=*/12, /*dst_start_idx2=*/0,
432                /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
433
434  std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548,
435                                 0.581674, 0.672433, 0.434190, 0.844357,
436                                 0.229587, 0.785629, 0.022065, 0.753082,
437                                 0.422080, 0.539481, 0.878386, 0.168965};
438
439  EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
440}
441
442}  // namespace toco
443