math_grad_test.cc revision a373b1f74215e44920bf9362a51bece530edf88a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.

class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN,
    REAL,
    IMAG,
    CONJ,
    COMPLEX,
    ANGLE
  };

  // Builds y = <op_type>(x) on a [2, 3, 2] placeholder whose test data is
  // filled element-by-element from x_fn, then checks the registered gradient
  // against the gradient checker's numeric estimate.
  template <typename X_T, typename Y_T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
    TF_ASSERT_OK(scope_.status());
    DataType x_type = DataTypeToEnum<X_T>::v();
    TensorShape shape({2, 3, 2});
    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
    Tensor x_data(x_type, shape);
    auto x_data_flat = x_data.flat<X_T>();
    for (int i = 0; i < x_data_flat.size(); ++i) {
      x_data_flat(i) = x_fn(i);
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
      case COMPLEX:
        y = Complex(scope_, x, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
                                                        shape, &max_error)));
    EXPECT_LT(max_error, 1e-3f);
  }
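
  // RV and CRV below return an element of v chosen uniformly at random, so
  // each run exercises a different mix of the hand-picked test inputs.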
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(ABS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(NEG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  TestCWiseGrad<float, float>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
  TestCWiseGrad<float, float>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  TestCWiseGrad<float, float>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) {
    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
  };
  TestCWiseGrad<float, float>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
  };
  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
  TestCWiseGrad<float, float>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
}
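
// Log1p computes log(1 + x) accurately for small x, which is presumably why
// the inputs below include 0 and 1e-6 as well as much larger values.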
TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  TestCWiseGrad<float, float>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
  TestCWiseGrad<float, float>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
}
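
// Acosh is real-valued only for inputs >= 1, so the test points below are
// chosen strictly greater than 1.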
TEST_F(CWiseUnaryGradTest, Acosh) {
  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
  TestCWiseGrad<float, float>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  TestCWiseGrad<float, float>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
  TestCWiseGrad<float, float>(ASIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
  TestCWiseGrad<float, float>(ACOS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64, complex64>(TAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(ATAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Real) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(REAL, x_fn);
}

TEST_F(CWiseUnaryGradTest, Imag) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(IMAG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Conj) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
}

TEST_F(CWiseUnaryGradTest, Complex) {
  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
}

TEST_F(CWiseUnaryGradTest, Angle) {
  auto x_fn = [this](const int i) {
    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
  };
  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
}
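
// Tests the MatMul and BatchMatMul gradients over randomly generated but
// mutually compatible operand shapes, in every combination of transposed
// (or, for the batch case, adjointed) inputs.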
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  template <typename T>
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    TF_ASSERT_OK(root_.status());
    // Generate random (but compatible) shapes for matrix multiplication.
    std::vector<TensorShape> shapes;
    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
    TensorShape x_shape = shapes[0];
    TensorShape y_shape = shapes[1];
    TensorShape z_shape = shapes[2];
    auto x =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
    auto y =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }
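
  // Fills *shapes with mutually compatible random shapes for x, y, and
  // z = MatMul(x, y), swapping the last two dimensions of an operand when it
  // is marked as transposed.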
  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
                        std::vector<TensorShape>* shapes) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    shapes->push_back(x_shape);

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    shapes->push_back(y_shape);

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    shapes->push_back(z_shape);
  }

  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(false, false, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(false, true, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(false, false, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(false, true, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(true, true, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(true, true, true);
}
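
// Tests gradients of ops with multiple inputs and/or outputs. RunTest wraps
// ComputeGradientError, either over lists of placeholder-fed inputs and
// outputs, or over a single input with a fixed initial value.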
class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, x, x_init_value, y, y_shape, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (= 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (= 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, MinMulti) {
  // Test the gradient when there are multiple minima.
  // We cannot use a test Tensor that directly contains multiple minima: the
  // numeric estimator perturbs each entry in turn, which changes how many
  // minima exist and therefore yields incorrect gradients.
  // Instead, we broadcast-multiply a single input against a larger tensor of
  // equal values, and apply reduce_min to the product.
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // y is the [3]-shaped tensor reduced along dimension 0, so it is [1]-shaped.
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
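
// For y = x1 / x2, dy/dx1 = 1 / x2 and dy/dx2 = -x1 / x2^2, so a denominator
// near zero destabilizes both the op and the finite-difference estimate. The
// two division tests below therefore keep the denominator at least 1.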
TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Lgamma) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Lgamma(scope_, x);
  // Select values to avoid instability when computing finite differences.
  // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg
  Tensor x_init_value =
      test::AsTensor<float>({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
  // TODO(suharshs): add test case for complex values
}

}  // namespace
}  // namespace tensorflow