1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#include "main.h"
11
12#include <Eigen/CXX11/Tensor>
13
14using Eigen::Tensor;
15
16template<int DataLayout>
17static void test_dimension_failures()
18{
19  Tensor<int, 3, DataLayout> left(2, 3, 1);
20  Tensor<int, 3, DataLayout> right(3, 3, 1);
21  left.setRandom();
22  right.setRandom();
23
24  // Okay; other dimensions are equal.
25  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26
27  // Dimension mismatches.
28  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30
31  // Axis > NumDims or < 0.
32  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34}
35
36template<int DataLayout>
37static void test_static_dimension_failure()
38{
39  Tensor<int, 2, DataLayout> left(2, 3);
40  Tensor<int, 3, DataLayout> right(2, 3, 1);
41
42#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43  // Technically compatible, but we static assert that the inputs have same
44  // NumDims.
45  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46#endif
47
48  // This can be worked around in this case.
49  Tensor<int, 3, DataLayout> concatenation = left
50      .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51      .concatenate(right, 0);
52  Tensor<int, 2, DataLayout> alternative = left
53      .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0);
54}
55
56template<int DataLayout>
57static void test_simple_concatenation()
58{
59  Tensor<int, 3, DataLayout> left(2, 3, 1);
60  Tensor<int, 3, DataLayout> right(2, 3, 1);
61  left.setRandom();
62  right.setRandom();
63
64  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
65  VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
66  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
67  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
68  for (int j = 0; j < 3; ++j) {
69    for (int i = 0; i < 2; ++i) {
70      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
71    }
72    for (int i = 2; i < 4; ++i) {
73      VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
74    }
75  }
76
77  concatenation = left.concatenate(right, 1);
78  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
79  VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
80  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
81  for (int i = 0; i < 2; ++i) {
82    for (int j = 0; j < 3; ++j) {
83      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
84    }
85    for (int j = 3; j < 6; ++j) {
86      VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
87    }
88  }
89
90  concatenation = left.concatenate(right, 2);
91  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
92  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
93  VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
94  for (int i = 0; i < 2; ++i) {
95    for (int j = 0; j < 3; ++j) {
96      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
97      VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
98    }
99  }
100}
101
102
103// TODO(phli): Add test once we have a real vectorized implementation.
104// static void test_vectorized_concatenation() {}
105
106static void test_concatenation_as_lvalue()
107{
108  Tensor<int, 2> t1(2, 3);
109  Tensor<int, 2> t2(2, 3);
110  t1.setRandom();
111  t2.setRandom();
112
113  Tensor<int, 2> result(4, 3);
114  result.setRandom();
115  t1.concatenate(t2, 0) = result;
116
117  for (int i = 0; i < 2; ++i) {
118    for (int j = 0; j < 3; ++j) {
119      VERIFY_IS_EQUAL(t1(i, j), result(i, j));
120      VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
121    }
122  }
123}
124
125
126void test_cxx11_tensor_concatenation()
127{
128   CALL_SUBTEST(test_dimension_failures<ColMajor>());
129   CALL_SUBTEST(test_dimension_failures<RowMajor>());
130   CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
131   CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
132   CALL_SUBTEST(test_simple_concatenation<ColMajor>());
133   CALL_SUBTEST(test_simple_concatenation<RowMajor>());
134   // CALL_SUBTEST(test_vectorized_concatenation());
135   CALL_SUBTEST(test_concatenation_as_lvalue());
136
137}
138