1// This file is part of Eigen, a lightweight C++ template library 2// for linear algebra. 3// 4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5// 6// This Source Code Form is subject to the terms of the Mozilla 7// Public License v. 2.0. If a copy of the MPL was not distributed 8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10#include "main.h" 11 12#include <Eigen/CXX11/Tensor> 13 14using Eigen::Tensor; 15 16template<int DataLayout> 17static void test_dimension_failures() 18{ 19 Tensor<int, 3, DataLayout> left(2, 3, 1); 20 Tensor<int, 3, DataLayout> right(3, 3, 1); 21 left.setRandom(); 22 right.setRandom(); 23 24 // Okay; other dimensions are equal. 25 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 26 27 // Dimension mismatches. 28 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); 29 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); 30 31 // Axis > NumDims or < 0. 32 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); 33 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); 34} 35 36template<int DataLayout> 37static void test_static_dimension_failure() 38{ 39 Tensor<int, 2, DataLayout> left(2, 3); 40 Tensor<int, 3, DataLayout> right(2, 3, 1); 41 42#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE 43 // Technically compatible, but we static assert that the inputs have same 44 // NumDims. 45 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 46#endif 47 48 // This can be worked around in this case. 49 Tensor<int, 3, DataLayout> concatenation = left 50 .reshape(Tensor<int, 3>::Dimensions(2, 3, 1)) 51 .concatenate(right, 0); 52 Tensor<int, 2, DataLayout> alternative = left 53 .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0); 54} 55 56template<int DataLayout> 57static void test_simple_concatenation() 58{ 59 Tensor<int, 3, DataLayout> left(2, 3, 1); 60 Tensor<int, 3, DataLayout> right(2, 3, 1); 61 left.setRandom(); 62 right.setRandom(); 63 64 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 65 VERIFY_IS_EQUAL(concatenation.dimension(0), 4); 66 VERIFY_IS_EQUAL(concatenation.dimension(1), 3); 67 VERIFY_IS_EQUAL(concatenation.dimension(2), 1); 68 for (int j = 0; j < 3; ++j) { 69 for (int i = 0; i < 2; ++i) { 70 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 71 } 72 for (int i = 2; i < 4; ++i) { 73 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); 74 } 75 } 76 77 concatenation = left.concatenate(right, 1); 78 VERIFY_IS_EQUAL(concatenation.dimension(0), 2); 79 VERIFY_IS_EQUAL(concatenation.dimension(1), 6); 80 VERIFY_IS_EQUAL(concatenation.dimension(2), 1); 81 for (int i = 0; i < 2; ++i) { 82 for (int j = 0; j < 3; ++j) { 83 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 84 } 85 for (int j = 3; j < 6; ++j) { 86 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); 87 } 88 } 89 90 concatenation = left.concatenate(right, 2); 91 VERIFY_IS_EQUAL(concatenation.dimension(0), 2); 92 VERIFY_IS_EQUAL(concatenation.dimension(1), 3); 93 VERIFY_IS_EQUAL(concatenation.dimension(2), 2); 94 for (int i = 0; i < 2; ++i) { 95 for (int j = 0; j < 3; ++j) { 96 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 97 VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); 98 } 99 } 100} 101 102 103// TODO(phli): Add test once we have a real vectorized implementation. 104// static void test_vectorized_concatenation() {} 105 106static void test_concatenation_as_lvalue() 107{ 108 Tensor<int, 2> t1(2, 3); 109 Tensor<int, 2> t2(2, 3); 110 t1.setRandom(); 111 t2.setRandom(); 112 113 Tensor<int, 2> result(4, 3); 114 result.setRandom(); 115 t1.concatenate(t2, 0) = result; 116 117 for (int i = 0; i < 2; ++i) { 118 for (int j = 0; j < 3; ++j) { 119 VERIFY_IS_EQUAL(t1(i, j), result(i, j)); 120 VERIFY_IS_EQUAL(t2(i, j), result(i+2, j)); 121 } 122 } 123} 124 125 126void test_cxx11_tensor_concatenation() 127{ 128 CALL_SUBTEST(test_dimension_failures<ColMajor>()); 129 CALL_SUBTEST(test_dimension_failures<RowMajor>()); 130 CALL_SUBTEST(test_static_dimension_failure<ColMajor>()); 131 CALL_SUBTEST(test_static_dimension_failure<RowMajor>()); 132 CALL_SUBTEST(test_simple_concatenation<ColMajor>()); 133 CALL_SUBTEST(test_simple_concatenation<RowMajor>()); 134 // CALL_SUBTEST(test_vectorized_concatenation()); 135 CALL_SUBTEST(test_concatenation_as_lvalue()); 136 137} 138