12b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// This file is part of Eigen, a lightweight C++ template library
22b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// for linear algebra.
32b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang//
42b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
52b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang//
62b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// This Source Code Form is subject to the terms of the Mozilla
72b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// Public License v. 2.0. If a copy of the MPL was not distributed
82b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
92b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
102b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#include "main.h"
112b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
122b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#include <Eigen/CXX11/Tensor>
132b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
142b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangusing Eigen::Tensor;
152b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
162b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangtemplate<int DataLayout>
172b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_dimension_failures()
182b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{
192b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> left(2, 3, 1);
202b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> right(3, 3, 1);
212b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  left.setRandom();
222b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  right.setRandom();
232b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
242b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // Okay; other dimensions are equal.
252b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
262b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
272b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // Dimension mismatches.
282b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
292b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
302b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
312b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // Axis > NumDims or < 0.
322b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
332b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
342b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang}
352b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
362b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangtemplate<int DataLayout>
372b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_static_dimension_failure()
382b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{
392b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 2, DataLayout> left(2, 3);
402b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> right(2, 3, 1);
412b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
422b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
432b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // Technically compatible, but we static assert that the inputs have same
442b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // NumDims.
452b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
462b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#endif
472b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
482b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  // This can be worked around in this case.
492b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> concatenation = left
502b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
512b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      .concatenate(right, 0);
522b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 2, DataLayout> alternative = left
532b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0);
542b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang}
552b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
562b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangtemplate<int DataLayout>
572b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_simple_concatenation()
582b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{
592b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> left(2, 3, 1);
602b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> right(2, 3, 1);
612b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  left.setRandom();
622b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  right.setRandom();
632b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
642b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
652b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
662b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
672b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
682b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  for (int j = 0; j < 3; ++j) {
692b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int i = 0; i < 2; ++i) {
702b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
712b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
722b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int i = 2; i < 4; ++i) {
732b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
742b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
752b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  }
762b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
772b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  concatenation = left.concatenate(right, 1);
782b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
792b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
802b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
812b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  for (int i = 0; i < 2; ++i) {
822b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int j = 0; j < 3; ++j) {
832b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
842b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
852b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int j = 3; j < 6; ++j) {
862b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
872b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
882b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  }
892b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
902b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  concatenation = left.concatenate(right, 2);
912b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
922b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
932b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
942b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  for (int i = 0; i < 2; ++i) {
952b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int j = 0; j < 3; ++j) {
962b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
972b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
982b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
992b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  }
1002b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang}
1012b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1022b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1032b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// TODO(phli): Add test once we have a real vectorized implementation.
1042b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang// static void test_vectorized_concatenation() {}
1052b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1062b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_concatenation_as_lvalue()
1072b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{
1082b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 2> t1(2, 3);
1092b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 2> t2(2, 3);
1102b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  t1.setRandom();
1112b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  t2.setRandom();
1122b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1132b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  Tensor<int, 2> result(4, 3);
1142b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  result.setRandom();
1152b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  t1.concatenate(t2, 0) = result;
1162b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1172b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  for (int i = 0; i < 2; ++i) {
1182b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    for (int j = 0; j < 3; ++j) {
1192b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(t1(i, j), result(i, j));
1202b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang      VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
1212b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang    }
1222b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang  }
1232b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang}
1242b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1252b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1262b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangvoid test_cxx11_tensor_concatenation()
1272b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{
1282b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_dimension_failures<ColMajor>());
1292b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_dimension_failures<RowMajor>());
1302b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
1312b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
1322b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_simple_concatenation<ColMajor>());
1332b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_simple_concatenation<RowMajor>());
1342b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   // CALL_SUBTEST(test_vectorized_concatenation());
1352b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang   CALL_SUBTEST(test_concatenation_as_lvalue());
1362b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang
1372b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang}
138