1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <tuple> 16#include <vector> 17 18#include <gmock/gmock.h> 19#include <gtest/gtest.h> 20#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" 21#include "tensorflow/contrib/lite/toco/model.h" 22#include "tensorflow/contrib/lite/toco/tooling_util.h" 23 24namespace toco { 25 26namespace { 27 28// A gmock matcher that check that elements of a float vector match to a given 29// tolerance. 30std::vector<testing::Matcher<float>> ArrayFloatNear( 31 const std::vector<float>& values, float max_abs_error = 1e-5) { 32 std::vector<testing::Matcher<float>> matchers; 33 matchers.reserve(values.size()); 34 for (const float& v : values) { 35 matchers.emplace_back(testing::FloatNear(v, max_abs_error)); 36 } 37 return matchers; 38} 39} // namespace 40 41class CopyArrayDataTest : public ::testing::Test { 42 public: 43 CopyArrayDataTest() {} 44 45 void PrepareBuffers(Model* model, std::initializer_list<float> src_data, 46 int src_dim_1, int src_dim_2, 47 std::initializer_list<float> dst_data, int dst_dim_1, 48 int dst_dim_2) { 49 string src_array = "src_array"; 50 src_buffer_ = CreateFloatArrayBuffer( 51 model, &src_array, 52 src_dim_2 == 1 ? Shape({src_dim_1}) : Shape({src_dim_1, src_dim_2})); 53 PopulateBuffer(src_buffer_, src_data); 54 string dst_array = "dst_array"; 55 dst_buffer_ = CreateFloatArrayBuffer( 56 model, &dst_array, 57 dst_dim_2 == 1 ? Shape({dst_dim_1}) : Shape({dst_dim_1, dst_dim_2})); 58 PopulateBuffer(dst_buffer_, dst_data); 59 } 60 61 Buffer<ArrayDataType::kFloat>* GetSrcBuffer() { return src_buffer_; } 62 Buffer<ArrayDataType::kFloat>* GetDstBuffer() { return dst_buffer_; } 63 64 void PopulateBuffer(Buffer<ArrayDataType::kFloat>* buffer, 65 const std::vector<float>& init_data) { 66 for (int i = 0; i < init_data.size(); i++) { 67 buffer->data[i] = init_data[i]; 68 } 69 } 70 void UpdateBuffer(Buffer<ArrayDataType::kFloat>* buffer, 71 std::initializer_list<float> data) { 72 buffer->data.resize(data.size()); 73 PopulateBuffer(buffer, data); 74 } 75 76 private: 77 Buffer<ArrayDataType::kFloat>* src_buffer_; 78 Buffer<ArrayDataType::kFloat>* dst_buffer_; 79}; 80 81// Copy from 1 big 2D array to 8 smaller ones. 82TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes2D) { 83 // Init src_buffer, dst_buffer. 84 Model model; 85 std::initializer_list<float> large_tf_weight_data = { 86 -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, 87 -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, 88 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, 89 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, 90 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, 91 -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, 92 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, 93 -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, 94 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, 95 -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, 96 -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, 97 -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, 98 -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, 99 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, 100 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, 101 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, 102 -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, 103 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, 104 -0.446775, -0.418363, 0.019743, 0.042025}; 105 std::initializer_list<float> tflite_lstm_input_weight = {0, 0, 0, 0, 0, 0, 106 0, 0, 0, 0, 0, 0}; 107 PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, 108 /*src_dim_2=*/7, tflite_lstm_input_weight, 109 /*dst_dim_1=*/4, /*dst_dim_2=*/3); 110 111 // Copy src starts at (0,0), size (4,3). 112 CopyArrayData(*(GetSrcBuffer()), 113 /*src_stride=*/7, /*src_start_idx1=*/0, 114 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, 115 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 116 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 117 std::vector<float> expected = {-0.320407, -0.108683, 0.406358, 0.170866, 118 0.084135, 0.201878, 0.045578, 0.149816, 119 -0.447073, -0.001985, 0.402193, 0.315517}; 120 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 121 122 // Copy src starts at (4,0), size (4,3). 123 CopyArrayData(*(GetSrcBuffer()), 124 /*src_stride=*/7, /*src_start_idx1=*/4, 125 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, 126 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 127 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 128 expected = {0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, 129 -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; 130 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 131 132 // Copy src starts at (8,0), size (4,3). 133 CopyArrayData(*(GetSrcBuffer()), 134 /*src_stride=*/7, /*src_start_idx1=*/8, 135 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, 136 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 137 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 138 expected = {0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, 139 -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; 140 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 141 142 // Copy src starts at (12,0), size (4,3). 143 CopyArrayData(*(GetSrcBuffer()), 144 /*src_stride=*/7, /*src_start_idx1=*/12, 145 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, 146 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 147 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 148 expected = {0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, 149 -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; 150 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 151 152 // New dst_buffer with size 16. 153 std::initializer_list<float> tflite_lstm_recurrent_weight = { 154 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; 155 PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, 156 /*src_dim_2=*/7, tflite_lstm_recurrent_weight, 157 /*dst_dim_1=*/4, /*dst_dim_2=*/4); 158 159 // Copy src starts at (0,3), size (4,4). 160 CopyArrayData(*(GetSrcBuffer()), 161 /*src_stride=*/7, /*src_start_idx1=*/0, 162 /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, 163 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 164 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 165 expected = {-0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, 166 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, 167 0.38258, 0.43599, 0.11986, 0.465195}; 168 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 169 170 // Copy src starts at (4,3), size (4,4). 171 CopyArrayData(*(GetSrcBuffer()), 172 /*src_stride=*/7, /*src_start_idx1=*/4, 173 /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, 174 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 175 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 176 expected = {0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, 177 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, 178 0.204032, -0.375317, -0.041911, 0.051664}; 179 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 180 181 // Copy src starts at (8,3), size (4,4). 182 CopyArrayData(*(GetSrcBuffer()), 183 /*src_stride=*/7, /*src_start_idx1=*/8, 184 /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, 185 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 186 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 187 expected = {-0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, 188 -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, 189 0.048654, -0.38582, 0.411018, -0.315606}; 190 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 191 192 // Copy src starts at (12,3), size (4,4). 193 CopyArrayData(*(GetSrcBuffer()), 194 /*src_stride=*/7, /*src_start_idx1=*/12, 195 /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, 196 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 197 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 198 expected = {-0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, 199 -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, 200 -0.446775, -0.418363, 0.019743, 0.042025}; 201 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 202} 203 204// Copy from 1 big 1D array to 4 small ones. 205TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes1D) { 206 // Init src_buffer, dst_buffer. 207 Model model; 208 std::initializer_list<float> large_tf_bias_data = { 209 0.980304, 0.419808, 0.080278, 0.728548, 0.581674, 0.672433, 210 0.434190, 0.844357, 0.229587, 0.785629, 0.022065, 0.753082, 211 0.422080, 0.539481, 0.878386, 0.168965}; 212 std::initializer_list<float> tflite_lstm_i_bias = {0, 0, 0, 0}; 213 PrepareBuffers(&model, large_tf_bias_data, /*src_dim_1=*/16, 214 /*src_dim_2=*/1, tflite_lstm_i_bias, 215 /*dst_dim_1=*/4, /*dst_dim_2=*/1); 216 217 // Copy starts at (0,), size (4,). 218 CopyArrayData(*(GetSrcBuffer()), 219 /*src_stride=*/1, /*src_start_idx1=*/0, 220 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 221 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 222 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 223 std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548}; 224 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 225 226 // Copy starts at (4,), size (4,). 227 CopyArrayData(*(GetSrcBuffer()), 228 /*src_stride=*/1, /*src_start_idx1=*/4, 229 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 230 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 231 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 232 expected = {0.581674, 0.672433, 0.434190, 0.844357}; 233 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 234 235 // Copy starts at (8,), size (4,). 236 CopyArrayData(*(GetSrcBuffer()), 237 /*src_stride=*/1, /*src_start_idx1=*/8, 238 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 239 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 240 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 241 expected = {0.229587, 0.785629, 0.022065, 0.753082}; 242 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 243 244 // Copy starts at (12,), size (4,). 245 CopyArrayData(*(GetSrcBuffer()), 246 /*src_stride=*/1, /*src_start_idx1=*/12, 247 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 248 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 249 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 250 expected = {0.422080, 0.539481, 0.878386, 0.168965}; 251 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 252} 253 254// Copy from 8 small 2D arrayes to 1 big one. 255TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray2D) { 256 // Init src_buffer, dst_buffer. 257 Model model; 258 std::initializer_list<float> large_tf_weights_data = { 259 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 260 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 261 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 262 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 263 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; 264 265 // Copy dst starts (0, 0), size (4, 3). 266 std::initializer_list<float> tflite_lstm_i2i_weight = { 267 -0.320407, -0.108683, 0.406358, 0.170866, 0.084135, 0.201878, 268 0.045578, 0.149816, -0.447073, -0.001985, 0.402193, 0.315517}; 269 PrepareBuffers(&model, tflite_lstm_i2i_weight, /*src_dim_1=*/4, 270 /*src_dim_2=*/3, large_tf_weights_data, 271 /*dst_dim_1=*/16, /*dst_dim_2=*/7); 272 CopyArrayData(*(GetSrcBuffer()), 273 /*src_stride=*/3, /*src_start_idx1=*/0, 274 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 275 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 276 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 277 278 // Copy dst starts (4, 0), size (4, 3). 279 std::initializer_list<float> tflite_lstm_i2c_weight = { 280 0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, 281 -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; 282 PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2c_weight); 283 CopyArrayData(*(GetSrcBuffer()), 284 /*src_stride=*/3, /*src_start_idx1=*/0, 285 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 286 /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, 287 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 288 289 // Copy dst starts (8, 0), size (4, 3). 290 std::initializer_list<float> tflite_lstm_i2f_weight = { 291 0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, 292 -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; 293 PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2f_weight); 294 CopyArrayData(*(GetSrcBuffer()), 295 /*src_stride=*/3, /*src_start_idx1=*/0, 296 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 297 /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, 298 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 299 300 // Copy dst starts (12, 0), size (4, 3). 301 std::initializer_list<float> tflite_lstm_i2o_weight = { 302 0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, 303 -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; 304 PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2o_weight); 305 CopyArrayData(*(GetSrcBuffer()), 306 /*src_stride=*/3, /*src_start_idx1=*/0, 307 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 308 /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, 309 /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); 310 311 // Copy dst starts (0, 3), size (4, 4). 312 std::initializer_list<float> tflite_lstm_i2r_weight = { 313 -0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, 314 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, 315 0.38258, 0.43599, 0.11986, 0.465195}; 316 UpdateBuffer(GetSrcBuffer(), tflite_lstm_i2r_weight); 317 CopyArrayData(*(GetSrcBuffer()), 318 /*src_stride=*/4, /*src_start_idx1=*/0, 319 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 320 /*dst_start_idx1=*/0, /*dst_start_idx2=*/3, 321 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 322 323 // Copy dst starts (4, 3), size (4, 4). 324 std::initializer_list<float> tflite_lstm_c2r_weight = { 325 0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, 326 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, 327 0.204032, -0.375317, -0.041911, 0.051664}; 328 PopulateBuffer(GetSrcBuffer(), tflite_lstm_c2r_weight); 329 CopyArrayData(*(GetSrcBuffer()), 330 /*src_stride=*/4, /*src_start_idx1=*/0, 331 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 332 /*dst_start_idx1=*/4, /*dst_start_idx2=*/3, 333 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 334 335 // Copy dst starts (8, 3), size (4, 4). 336 std::initializer_list<float> tflite_lstm_f2r_weight = { 337 -0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, 338 -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, 339 0.048654, -0.38582, 0.411018, -0.315606}; 340 PopulateBuffer(GetSrcBuffer(), tflite_lstm_f2r_weight); 341 CopyArrayData(*(GetSrcBuffer()), 342 /*src_stride=*/4, /*src_start_idx1=*/0, 343 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 344 /*dst_start_idx1=*/8, /*dst_start_idx2=*/3, 345 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 346 347 // Copy dst starts (12, 3), size (4, 4). 348 std::initializer_list<float> tflite_lstm_o2r_weight = { 349 -0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, 350 -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, 351 -0.446775, -0.418363, 0.019743, 0.042025}; 352 PopulateBuffer(GetSrcBuffer(), tflite_lstm_o2r_weight); 353 CopyArrayData(*(GetSrcBuffer()), 354 /*src_stride=*/4, /*src_start_idx1=*/0, 355 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, 356 /*dst_start_idx1=*/12, /*dst_start_idx2=*/3, 357 /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); 358 359 std::vector<float> expected = { 360 -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, 361 -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, 362 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, 363 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, 364 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, 365 -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, 366 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, 367 -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, 368 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, 369 -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, 370 -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, 371 -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, 372 -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, 373 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, 374 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, 375 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, 376 -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, 377 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, 378 -0.446775, -0.418363, 0.019743, 0.042025}; 379 380 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 381} 382 383// Copy from 4 small 1D arrayes to 1 big one. 384TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray1D) { 385 // Init src_buffer, dst_buffer. 386 Model model; 387 std::initializer_list<float> large_tf_bias_data = {0, 0, 0, 0, 0, 0, 0, 0, 388 0, 0, 0, 0, 0, 0, 0, 0}; 389 390 std::initializer_list<float> tflite_lstm_i_bias = {0.980304, 0.419808, 391 0.080278, 0.728548}; 392 393 PrepareBuffers(&model, tflite_lstm_i_bias, /*src_dim_1=*/4, 394 /*src_dim_2=*/1, large_tf_bias_data, 395 /*dst_dim_1=*/16, /*dst_dim_2=*/1); 396 397 // Copy starts at (0,), size (4,). 398 CopyArrayData(*(GetSrcBuffer()), 399 /*src_stride=*/1, /*src_start_idx1=*/0, 400 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 401 /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, 402 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 403 404 // Copy starts at (4,), size (4,). 405 std::initializer_list<float> tflite_lstm_cell_bias = {0.581674, 0.672433, 406 0.434190, 0.844357}; 407 PopulateBuffer(GetSrcBuffer(), tflite_lstm_cell_bias); 408 CopyArrayData(*(GetSrcBuffer()), 409 /*src_stride=*/1, /*src_start_idx1=*/0, 410 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 411 /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, 412 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 413 414 // Copy starts at (8,0), size (4,). 415 std::initializer_list<float> tflite_lstm_forget_bias = {0.229587, 0.785629, 416 0.022065, 0.753082}; 417 PopulateBuffer(GetSrcBuffer(), tflite_lstm_forget_bias); 418 CopyArrayData(*(GetSrcBuffer()), 419 /*src_stride=*/1, /*src_start_idx1=*/0, 420 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 421 /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, 422 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 423 424 // Copy starts at (12,), size (4,). 425 std::initializer_list<float> tflite_lstm_output_bias = {0.422080, 0.539481, 426 0.878386, 0.168965}; 427 PopulateBuffer(GetSrcBuffer(), tflite_lstm_output_bias); 428 CopyArrayData(*(GetSrcBuffer()), 429 /*src_stride=*/1, /*src_start_idx1=*/0, 430 /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, 431 /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, 432 /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); 433 434 std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548, 435 0.581674, 0.672433, 0.434190, 0.844357, 436 0.229587, 0.785629, 0.022065, 0.753082, 437 0.422080, 0.539481, 0.878386, 0.168965}; 438 439 EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); 440} 441 442} // namespace toco 443