12b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#include "main.h" 22b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 32b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#include <Eigen/CXX11/Tensor> 42b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 52b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangusing Eigen::Tensor; 62b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangusing Eigen::RowMajor; 72b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 82b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_comparison_sugar() { 92b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang // we already trust comparisons between tensors, we're simply checking that 102b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang // the sugared versions are doing the same thing 112b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<int, 3> t(6, 7, 5); 122b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 132b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang t.setRandom(); 142b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang // make sure we have at least one value == 0 152b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang t(0,0,0) = 0; 162b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 172b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<bool,0> b; 182b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 192b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#define TEST_TENSOR_EQUAL(e1, e2) \ 202b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang b = ((e1) == (e2)).all(); \ 212b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang VERIFY(b()) 222b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 232b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0)) 242b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 252b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(==); 262b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(!=); 272b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(<=); 282b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(>=); 292b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(<); 302b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang TEST_OP(>); 312b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#undef TEST_OP 322b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang#undef TEST_TENSOR_EQUAL 332b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang} 342b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 352b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 362b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_scalar_sugar_add_mul() { 372b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> A(6, 7, 5); 382b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> B(6, 7, 5); 392b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang A.setRandom(); 402b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang B.setRandom(); 412b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 422b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float alpha = 0.43f; 432b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float beta = 0.21f; 442b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float gamma = 0.14f; 452b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 462b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta); 472b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> S = A * alpha + B * beta + gamma; 482b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> T = gamma + alpha * A + beta * B; 492b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 502b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang for (int i = 0; i < 6*7*5; ++i) { 512b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang VERIFY_IS_APPROX(R(i), S(i)); 522b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang VERIFY_IS_APPROX(R(i), T(i)); 532b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang } 542b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang} 552b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 562b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangstatic void test_scalar_sugar_sub_div() { 572b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> A(6, 7, 5); 582b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> B(6, 7, 5); 592b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang A.setRandom(); 602b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang B.setRandom(); 612b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 622b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float alpha = 0.43f; 632b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float beta = 0.21f; 642b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float gamma = 0.14f; 652b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang const float delta = 0.32f; 662b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 672b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> R = A.constant(gamma) - A / A.constant(alpha) 682b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang - B.constant(beta) / B - A.constant(delta); 692b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang Tensor<float, 3> S = gamma - A / alpha - beta / B - delta; 702b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 712b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang for (int i = 0; i < 6*7*5; ++i) { 722b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang VERIFY_IS_APPROX(R(i), S(i)); 732b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang } 742b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang} 752b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang 762b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wangvoid test_cxx11_tensor_sugar() 772b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang{ 782b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang CALL_SUBTEST(test_comparison_sugar()); 792b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang CALL_SUBTEST(test_scalar_sugar_add_mul()); 802b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang CALL_SUBTEST(test_scalar_sugar_sub_div()); 812b8756b6f1de65d3f8bffab45be6c44ceb7411fcMiao Wang} 82