10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gmock/gmock.h>
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gtest/gtest.h>
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// A gmock matcher that check that elements of a float vector match to a given
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// tolerance.
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<testing::Matcher<float>> ArrayFloatNear(
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const std::vector<float>& values, float max_abs_error = 1e-5) {
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<testing::Matcher<float>> matchers;
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  matchers.reserve(values.size());
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const float& v : values) {
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    matchers.emplace_back(testing::FloatNear(v, max_abs_error));
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return matchers;
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
// The following 3 tests make sure that concatenating along each of the three
// axes matches the TensorFlow results listed below:
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// ConcatAtAxis0 test:
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// t0 = tf.concat([x0, x1, x2, x3], 0)
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// [[[ 0  1]
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 2  3]]
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[ 4  5]
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 6  7]]
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[10 11]
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [12 13]]
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[14 15]
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [16 17]]
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[20 21]
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [22 23]]
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[24 25]
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [26 27]]
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[30 31]
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [32 33]]
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[34 35]
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [36 37]]]
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// ConcatAtAxis1 test:
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// t1 = tf.concat([x0, x1, x2, x3], 1)
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// [[[ 0  1]
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 2  3]
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [10 11]
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [12 13]
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [20 21]
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [22 23]
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [30 31]
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [32 33]]
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[ 4  5]
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 6  7]
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [14 15]
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [16 17]
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [24 25]
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [26 27]
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [34 35]
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [36 37]]]
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// ConcatAtAxis2 test:
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// t2 = tf.concat([x0, x1, x2, x3], 2)
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// [[[ 0  1 10 11 20 21 30 31]
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 2  3 12 13 22 23 32 33]]
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//  [[ 4  5 14 15 24 25 34 35]
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle//   [ 6  7 16 17 26 27 36 37]]]
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleclass ResolveConstantConcatenationTest : public ::testing::Test {
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle protected:
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  ResolveConstantConcatenationTest() {}
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Prepare a hypothetical TOCO model with one Concatenation operator in it
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // together with 4 arrays as its inputs.
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // It receives the dimension of concatenation as input.
1112eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  void PrepareModel(Model* model, int axis) {
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    std::vector<string> concat_input_names = {"array0", "array1", "array2",
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                              "array3"};
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const int kDim = 3;
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const int kElementPerDim = 2;
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const int kBufSize = 8;
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const int kNumArrays = 4;
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    static float in_buf[kNumArrays][kBufSize] = {
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        {0., 1., 2., 3., 4., 5., 6., 7.},
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        {10., 11., 12., 13., 14., 15., 16., 17.},
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        {20., 21., 22., 23., 24., 25., 26., 27.},
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        {30., 31., 32., 33., 34., 35., 36., 37.}};
1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    int cnt = 0;
1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (const string& concat_input_name : concat_input_names) {
1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      Array& in_array = model->GetOrCreateArray(concat_input_name);
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      in_array.data_type = ArrayDataType::kFloat;
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // Initialize shape for the input  array.
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      Shape* in_array_shape = in_array.mutable_shape();
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      for (int i = 0; i < kDim; i++) {
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        in_array_shape_dim->push_back(kElementPerDim);
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      auto& in_array_buffer =
1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      in_array_buffer.data.resize(kBufSize);
1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      float* buf_ptr =
1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data();
1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr);
1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      cnt++;
1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    auto* concatenation_op = new ConcatenationOperator;
1442eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower    concatenation_op->axis = axis;
1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    concatenation_op->inputs = concat_input_names;
1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    concatenation_op->outputs = {"concat_op_outputs"};
1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    out_array.data_type = ArrayDataType::kFloat;
1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Shape* out_array_shape = out_array.mutable_shape();
1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    out_array_shape_dim->resize(kDim);
1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (int i = 0; i < kDim; i++) {
1532eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower      if (i == axis) {
1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim;
1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      } else {
1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        (*out_array_shape_dim)[i] = kElementPerDim;
1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    model->operators.push_back(std::unique_ptr<Operator>(concatenation_op));
1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle};
1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Model model;
1652eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  const int axis = 0;
1662eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  PrepareModel(&model, axis);
1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  GraphTransformationsSet graph_transformation_set;
1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
170ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 5);
1710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
172ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 1);
1730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
174ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
1750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
1760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ElementsAreArray(ArrayFloatNear(
1770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  10., 11., 12.,
1780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
1790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   26., 27., 30., 31., 32., 33., 34., 35., 36., 37.})));
1800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
1830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Model model;
1842eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  const int axis = 1;
1852eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  PrepareModel(&model, axis);
1860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  GraphTransformationsSet graph_transformation_set;
1880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
189ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 5);
1900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
191ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 1);
1920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
193ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
1940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
1950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ElementsAreArray(ArrayFloatNear(
1960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {0.,  1.,  2.,  3.,  10., 11., 12., 13., 20., 21., 22.,
1970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   23., 30., 31., 32., 33., 4.,  5.,  6.,  7.,  14., 15.,
1980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   16., 17., 24., 25., 26., 27., 34., 35., 36., 37.})));
1990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
2000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
2020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Model model;
2032eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  const int axis = 2;
2042eae1ac21ce28f3b2cafe9e12a25b3bddc475847A. Unique TensorFlower  PrepareModel(&model, axis);
2050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  GraphTransformationsSet graph_transformation_set;
2070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
208ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 5);
2090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
210ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  EXPECT_THAT(model.GetArrayMap().size(), 1);
2110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
212ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
2130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
2140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ElementsAreArray(ArrayFloatNear(
2150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {0.,  1.,  10., 11., 20., 21., 30., 31., 2.,  3.,  12.,
2160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   13., 22., 23., 32., 33., 4.,  5.,  14., 15., 24., 25.,
2170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   34., 35., 6.,  7.,  16., 17., 26., 27., 36., 37.})));
2180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
2190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
221