math_grad_test.cc revision 355e25ebcab64e833dfc987638c3e6c79d838266
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/framework/gradient_checker.h" 18#include "tensorflow/cc/framework/testutil.h" 19#include "tensorflow/cc/gradients/grad_testutil.h" 20#include "tensorflow/cc/ops/standard_ops.h" 21#include "tensorflow/core/framework/tensor_testutil.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23#include "tensorflow/core/lib/random/random.h" 24 25namespace tensorflow { 26using namespace ops; // NOLINT(build/namespaces) 27 28namespace { 29 30// TODO(andydavis) Test gradient function against numeric gradients output. 31// TODO(andydavis) As more gradients are added move common test functions 32// to a testutil library. 33 34class CWiseUnaryGradTest : public ::testing::Test { 35 protected: 36 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 37 38 enum UnaryOpType { 39 ABS, 40 NEG, 41 INV, 42 SQUARE, 43 SQRT, 44 RSQRT, 45 EXP, 46 EXPM1, 47 LOG, 48 LOG1P, 49 SINH, 50 COSH, 51 TANH, 52 ASINH, 53 ACOSH, 54 ATANH, 55 SIGMOID, 56 SIGN, 57 SIN, 58 COS, 59 ASIN, 60 ACOS, 61 TAN, 62 ATAN, 63 REAL, 64 IMAG, 65 CONJ, 66 COMPLEX, 67 ANGLE, 68 LGAMMA, 69 ERF 70 }; 71 72 template <typename X_T, typename Y_T> 73 void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) { 74 TF_ASSERT_OK(scope_.status()); 75 DataType x_type = DataTypeToEnum<X_T>::v(); 76 TensorShape shape({2, 3, 2}); 77 auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape)); 78 Tensor x_data(x_type, shape); 79 auto x_data_flat = x_data.flat<X_T>(); 80 for (int i = 0; i < x_data_flat.size(); ++i) { 81 x_data_flat(i) = x_fn(i); 82 } 83 84 Output y; 85 switch (op_type) { 86 case ABS: 87 y = Abs(scope_, x); 88 break; 89 case NEG: 90 y = Neg(scope_, x); 91 break; 92 case INV: 93 y = Reciprocal(scope_, x); 94 break; 95 case SQUARE: 96 y = Square(scope_, x); 97 break; 98 case SQRT: 99 y = Sqrt(scope_, x); 100 break; 101 case RSQRT: 102 y = Rsqrt(scope_, x); 103 break; 104 case EXP: 105 y = Exp(scope_, x); 106 break; 107 case EXPM1: 108 y = Expm1(scope_, x); 109 break; 110 case LOG: 111 y = Log(scope_, x); 112 break; 113 case LOG1P: 114 y = Log1p(scope_, x); 115 break; 116 case SINH: 117 y = Sinh(scope_, x); 118 break; 119 case COSH: 120 y = Cosh(scope_, x); 121 break; 122 case TANH: 123 y = Tanh(scope_, x); 124 break; 125 case ASINH: 126 y = Asinh(scope_, x); 127 break; 128 case ACOSH: 129 y = Acosh(scope_, x); 130 break; 131 case ATANH: 132 y = Atanh(scope_, x); 133 break; 134 case SIGMOID: 135 y = Sigmoid(scope_, x); 136 break; 137 case SIGN: 138 y = Sign(scope_, x); 139 break; 140 case SIN: 141 y = Sin(scope_, x); 142 break; 143 case COS: 144 y = Cos(scope_, x); 145 break; 146 case ASIN: 147 y = Asin(scope_, x); 148 break; 149 case ACOS: 150 y = Acos(scope_, x); 151 break; 152 case TAN: 153 y = Tan(scope_, x); 154 break; 155 case ATAN: 156 y = Atan(scope_, x); 157 break; 158 case REAL: 159 y = Real(scope_, x); 160 break; 161 case IMAG: 162 y = Imag(scope_, x); 163 break; 164 case CONJ: 165 y = Conj(scope_, x); 166 break; 167 case COMPLEX: 168 y = Complex(scope_, x, x); 169 break; 170 case ANGLE: 171 y = Angle(scope_, x); 172 break; 173 case LGAMMA: 174 y = Lgamma(scope_, x); 175 break; 176 case ERF: 177 y = Erf(scope_, x); 178 break; 179 } 180 181 float max_error; 182 TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y, 183 shape, &max_error))); 184 EXPECT_LT(max_error, 1e-3f); 185 } 186 187 float RV(const std::vector<float>& v) { 188 return v[random::New64() % v.size()]; 189 } 190 191 complex64 CRV(const std::vector<complex64>& v) { 192 return v[random::New64() % v.size()]; 193 } 194 195 complex64 conjugate(const complex64& val) { 196 return complex64(val.real(), -val.imag()); 197 } 198 199 Scope scope_; 200}; 201 202TEST_F(CWiseUnaryGradTest, Abs) { 203 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 204 TestCWiseGrad<float, float>(ABS, x_fn); 205} 206 207TEST_F(CWiseUnaryGradTest, Neg) { 208 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 209 TestCWiseGrad<float, float>(NEG, x_fn); 210} 211 212TEST_F(CWiseUnaryGradTest, Reciprocal) { 213 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 214 TestCWiseGrad<float, float>(INV, x_fn); 215} 216 217TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { 218 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 219 TestCWiseGrad<complex64, complex64>(INV, x_fn); 220} 221 222TEST_F(CWiseUnaryGradTest, Square) { 223 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 224 TestCWiseGrad<float, float>(SQUARE, x_fn); 225} 226 227TEST_F(CWiseUnaryGradTest, Square_Complex) { 228 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 229 TestCWiseGrad<complex64, complex64>(SQUARE, x_fn); 230} 231 232TEST_F(CWiseUnaryGradTest, Sqrt) { 233 auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); }; 234 TestCWiseGrad<float, float>(SQRT, x_fn); 235} 236 237TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { 238 auto x_fn = [this](const int i) { 239 return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); 240 }; 241 TestCWiseGrad<complex64, complex64>(SQRT, x_fn); 242} 243 244TEST_F(CWiseUnaryGradTest, Rsqrt) { 245 auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; 246 TestCWiseGrad<float, float>(RSQRT, x_fn); 247} 248 249TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { 250 auto x_fn = [this](const int i) { 251 return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); 252 }; 253 TestCWiseGrad<complex64, complex64>(RSQRT, x_fn); 254} 255 256TEST_F(CWiseUnaryGradTest, Exp) { 257 auto x_fn = [this](const int i) { 258 return RV({0, -1, 1, -1.5f, 1.5f, -2, 2}); 259 }; 260 TestCWiseGrad<float, float>(EXP, x_fn); 261} 262 263TEST_F(CWiseUnaryGradTest, Exp_Complex) { 264 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 265 TestCWiseGrad<complex64, complex64>(EXP, x_fn); 266} 267 268TEST_F(CWiseUnaryGradTest, Expm1) { 269 auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); }; 270 TestCWiseGrad<float, float>(EXPM1, x_fn); 271} 272 273TEST_F(CWiseUnaryGradTest, Expm1_Complex) { 274 auto x_fn = [this](const int i) { 275 return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}}); 276 }; 277 TestCWiseGrad<complex64, complex64>(EXPM1, x_fn); 278} 279 280TEST_F(CWiseUnaryGradTest, Log) { 281 auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); }; 282 TestCWiseGrad<float, float>(LOG, x_fn); 283} 284 285TEST_F(CWiseUnaryGradTest, Log_Complex) { 286 auto x_fn = [this](const int i) { 287 return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}}); 288 }; 289 TestCWiseGrad<complex64, complex64>(LOG, x_fn); 290} 291 292TEST_F(CWiseUnaryGradTest, Log1p) { 293 auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); }; 294 TestCWiseGrad<float, float>(LOG1P, x_fn); 295} 296 297TEST_F(CWiseUnaryGradTest, Log1p_Complex) { 298 auto x_fn = [this](const int i) { 299 return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); 300 }; 301 TestCWiseGrad<complex64, complex64>(LOG1P, x_fn); 302} 303 304TEST_F(CWiseUnaryGradTest, Sinh) { 305 auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); }; 306 TestCWiseGrad<float, float>(SINH, x_fn); 307} 308 309TEST_F(CWiseUnaryGradTest, Sinh_Complex) { 310 auto x_fn = [this](const int i) { 311 return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); 312 }; 313 TestCWiseGrad<complex64, complex64>(SINH, x_fn); 314} 315 316TEST_F(CWiseUnaryGradTest, Cosh) { 317 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 318 TestCWiseGrad<float, float>(COSH, x_fn); 319} 320 321TEST_F(CWiseUnaryGradTest, Cosh_Complex) { 322 auto x_fn = [this](const int i) { 323 return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); 324 }; 325 TestCWiseGrad<complex64, complex64>(COSH, x_fn); 326} 327 328TEST_F(CWiseUnaryGradTest, Tanh) { 329 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 330 TestCWiseGrad<float, float>(TANH, x_fn); 331} 332 333TEST_F(CWiseUnaryGradTest, Tanh_Complex) { 334 auto x_fn = [this](const int i) { 335 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 336 }; 337 TestCWiseGrad<complex64, complex64>(TANH, x_fn); 338} 339 340TEST_F(CWiseUnaryGradTest, Asinh) { 341 auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); }; 342 TestCWiseGrad<float, float>(ASINH, x_fn); 343} 344 345TEST_F(CWiseUnaryGradTest, Asinh_Complex) { 346 auto x_fn = [this](const int i) { 347 return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); 348 }; 349 TestCWiseGrad<complex64, complex64>(ASINH, x_fn); 350} 351 352TEST_F(CWiseUnaryGradTest, Acosh) { 353 auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); }; 354 TestCWiseGrad<float, float>(ACOSH, x_fn); 355} 356 357TEST_F(CWiseUnaryGradTest, Acosh_Complex) { 358 auto x_fn = [this](const int i) { 359 return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); 360 }; 361 TestCWiseGrad<complex64, complex64>(ACOSH, x_fn); 362} 363 364TEST_F(CWiseUnaryGradTest, Atanh) { 365 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); }; 366 TestCWiseGrad<float, float>(ATANH, x_fn); 367} 368 369TEST_F(CWiseUnaryGradTest, Atanh_Complex) { 370 auto x_fn = [this](const int i) { 371 return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}}); 372 }; 373 TestCWiseGrad<complex64, complex64>(ATANH, x_fn); 374} 375 376TEST_F(CWiseUnaryGradTest, Sigmoid) { 377 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 378 TestCWiseGrad<float, float>(SIGMOID, x_fn); 379} 380 381TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { 382 auto x_fn = [this](const int i) { 383 return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); 384 }; 385 TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn); 386} 387 388TEST_F(CWiseUnaryGradTest, Sign) { 389 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); }; 390 TestCWiseGrad<float, float>(SIGN, x_fn); 391} 392 393TEST_F(CWiseUnaryGradTest, Sin) { 394 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 395 TestCWiseGrad<float, float>(SIN, x_fn); 396} 397 398TEST_F(CWiseUnaryGradTest, Sin_Complex) { 399 auto x_fn = [this](const int i) { 400 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); 401 }; 402 TestCWiseGrad<complex64, complex64>(SIN, x_fn); 403} 404 405TEST_F(CWiseUnaryGradTest, Cos) { 406 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 407 TestCWiseGrad<float, float>(COS, x_fn); 408} 409 410TEST_F(CWiseUnaryGradTest, Cos_Complex) { 411 auto x_fn = [this](const int i) { 412 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); 413 }; 414 TestCWiseGrad<complex64, complex64>(COS, x_fn); 415} 416 417TEST_F(CWiseUnaryGradTest, Asin) { 418 auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); }; 419 TestCWiseGrad<float, float>(ASIN, x_fn); 420} 421 422TEST_F(CWiseUnaryGradTest, Asin_Complex) { 423 auto x_fn = [this](const int i) { 424 return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); 425 }; 426 // TODO(kbsriram) 427 // Enable test when the asin kernel supports complex numbers 428 if (false) { 429 TestCWiseGrad<complex64, complex64>(ASIN, x_fn); 430 } 431} 432 433TEST_F(CWiseUnaryGradTest, Acos) { 434 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); }; 435 TestCWiseGrad<float, float>(ACOS, x_fn); 436} 437 438TEST_F(CWiseUnaryGradTest, Acos_Complex) { 439 auto x_fn = [this](const int i) { 440 return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); 441 }; 442 // TODO(kbsriram) 443 // Add test when the acos kernel supports complex numbers 444 if (false) { 445 TestCWiseGrad<complex64, complex64>(ACOS, x_fn); 446 } 447} 448 449TEST_F(CWiseUnaryGradTest, Tan) { 450 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 451 TestCWiseGrad<float, float>(TAN, x_fn); 452} 453 454TEST_F(CWiseUnaryGradTest, Tan_Complex) { 455 auto x_fn = [this](const int i) { 456 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 457 }; 458 // TODO(kbsriram) 459 // Enable when tan kernel supports complex inputs 460 if (false) { 461 TestCWiseGrad<complex64, complex64>(TAN, x_fn); 462 } 463} 464 465TEST_F(CWiseUnaryGradTest, Atan) { 466 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 467 TestCWiseGrad<float, float>(ATAN, x_fn); 468} 469 470TEST_F(CWiseUnaryGradTest, Atan_Complex) { 471 auto x_fn = [this](const int i) { 472 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 473 }; 474 // TODO(kbsriram) 475 // Add test when the atan kernel supports complex numbers 476 if (false) { 477 TestCWiseGrad<complex64, complex64>(ATAN, x_fn); 478 } 479} 480 481TEST_F(CWiseUnaryGradTest, Real) { 482 auto x_fn = [this](const int i) { 483 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 484 }; 485 TestCWiseGrad<complex64, float>(REAL, x_fn); 486} 487 488TEST_F(CWiseUnaryGradTest, Imag) { 489 auto x_fn = [this](const int i) { 490 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 491 }; 492 TestCWiseGrad<complex64, float>(IMAG, x_fn); 493} 494 495TEST_F(CWiseUnaryGradTest, Conj) { 496 auto x_fn = [this](const int i) { 497 return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); 498 }; 499 TestCWiseGrad<complex64, complex64>(CONJ, x_fn); 500} 501 502TEST_F(CWiseUnaryGradTest, Complex) { 503 auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); }; 504 TestCWiseGrad<float, complex64>(COMPLEX, x_fn); 505} 506 507TEST_F(CWiseUnaryGradTest, Angle) { 508 auto x_fn = [this](const int i) { 509 return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}}); 510 }; 511 TestCWiseGrad<complex64, float>(ANGLE, x_fn); 512} 513 514TEST_F(CWiseUnaryGradTest, Lgamma) { 515 auto x_fn = [this](const int i) { 516 return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5}); 517 }; 518 TestCWiseGrad<float, float>(LGAMMA, x_fn); 519} 520 521TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { 522 auto x_fn = [this](const int i) { 523 return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}}); 524 }; 525 // TODO(kbsriram) 526 // Add test when the lgamma kernel supports complex numbers 527 if (false) { 528 TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn); 529 } 530} 531 532TEST_F(CWiseUnaryGradTest, Erf) { 533 auto x_fn = [this](const int i) { 534 return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3}); 535 }; 536 TestCWiseGrad<float, float>(ERF, x_fn); 537} 538 539TEST_F(CWiseUnaryGradTest, Erf_Complex) { 540 auto x_fn = [this](const int i) { 541 return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}}); 542 }; 543 // TODO(kbsriram) 544 // Add test when the erf kernel supports complex numbers 545 if (false) { 546 TestCWiseGrad<complex64, complex64>(ERF, x_fn); 547 } 548} 549 550class MathGradTest : public ::testing::Test { 551 protected: 552 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 553 554 template <typename T> 555 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { 556 TF_ASSERT_OK(root_.status()); 557 // Generate random (but compatible) shapes for matrix multiplication. 558 std::vector<TensorShape> shapes; 559 RandMatMulShapes(is_batch, t_x, t_y, &shapes); 560 TensorShape x_shape = shapes[0]; 561 TensorShape y_shape = shapes[1]; 562 TensorShape z_shape = shapes[2]; 563 auto x = 564 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape)); 565 auto y = 566 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape)); 567 Output z; 568 if (is_batch) { 569 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 570 } else { 571 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 572 } 573 574 float max_error; 575 TF_ASSERT_OK((ComputeGradientError<T, T, float>( 576 root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error))); 577 EXPECT_LT(max_error, 1e-3); 578 } 579 580 void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty, 581 std::vector<TensorShape>* shapes) { 582 // Choose a random batch size in [1, 4] 583 const int b = 1 + (random::New64() % 4); 584 // z = MatMul(x, y) 585 const int m = Rand(); 586 const int k = Rand(); 587 const int n = Rand(); 588 589 TensorShape x_shape; 590 if (is_batch) { 591 // x.shape = [b, m, k] 592 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); 593 } else { 594 // x.shape = [m, k] 595 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 596 } 597 shapes->push_back(x_shape); 598 599 TensorShape y_shape; 600 if (is_batch) { 601 // y.shape = [b, k, n] 602 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); 603 } else { 604 // y.shape = [k, n] 605 y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); 606 } 607 shapes->push_back(y_shape); 608 609 TensorShape z_shape; 610 if (is_batch) { 611 // z.shape = [b, m, n] 612 z_shape = TensorShape({b, m, n}); 613 } else { 614 // z.shape = [m, n] 615 z_shape = TensorShape({m, n}); 616 } 617 shapes->push_back(z_shape); 618 } 619 620 int Rand() { return 1 + (random::New64() % 10); } 621 622 Scope root_; 623}; 624 625TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 626 TestMatMulGrad<float>(false, false, false); 627} 628 629TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) { 630 TestMatMulGrad<complex64>(false, false, false); 631} 632 633TEST_F(MathGradTest, MatMulGrad_TransposeX) { 634 TestMatMulGrad<float>(false, true, false); 635} 636 637TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) { 638 TestMatMulGrad<complex64>(false, true, false); 639} 640 641TEST_F(MathGradTest, MatMulGrad_TransposeY) { 642 TestMatMulGrad<float>(false, false, true); 643} 644 645TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) { 646 TestMatMulGrad<complex64>(false, false, true); 647} 648 649TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 650 TestMatMulGrad<float>(false, true, true); 651} 652 653TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) { 654 TestMatMulGrad<complex64>(false, true, true); 655} 656 657TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { 658 TestMatMulGrad<float>(true, false, false); 659} 660 661TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) { 662 TestMatMulGrad<complex64>(true, false, false); 663} 664 665TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { 666 TestMatMulGrad<float>(true, true, false); 667} 668 669TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) { 670 TestMatMulGrad<complex64>(true, true, false); 671} 672 673TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { 674 TestMatMulGrad<float>(true, false, true); 675} 676 677TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) { 678 TestMatMulGrad<complex64>(true, false, true); 679} 680 681TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { 682 TestMatMulGrad<float>(true, true, true); 683} 684 685TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) { 686 TestMatMulGrad<complex64>(true, true, true); 687} 688 689class NaryGradTest : public ::testing::Test { 690 protected: 691 NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 692 693 void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes, 694 const OutputList& ys, const std::vector<TensorShape>& y_shapes) { 695 TF_ASSERT_OK(scope_.status()); 696 float max_error; 697 TF_ASSERT_OK((ComputeGradientError<float, float, float>( 698 scope_, xs, x_shapes, ys, y_shapes, &max_error))); 699 EXPECT_LT(max_error, 1e-3); 700 } 701 702 void RunTest(const Output& x, const Tensor& x_init_value, const Output& y, 703 const TensorShape& y_shape) { 704 TF_ASSERT_OK(scope_.status()); 705 float max_error; 706 TF_ASSERT_OK((ComputeGradientError<float, float, float>( 707 scope_, x, x_init_value, y, y_shape, &max_error))); 708 EXPECT_LT(max_error, 1e-3); 709 } 710 711 Scope scope_; 712}; 713 714TEST_F(NaryGradTest, Sum) { 715 TensorShape x_shape({2, 3, 5, 7}); 716 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 717 auto y = Sum(scope_, x, {1, -1}); 718 // y's shape is the result of reducing x along axes 1 and -1 (= 3) 719 TensorShape y_shape({2, 5}); 720 RunTest({x}, {x_shape}, {y}, {y_shape}); 721} 722 723TEST_F(NaryGradTest, Mean) { 724 TensorShape x_shape({2, 3, 5, 7}); 725 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 726 auto y = Mean(scope_, x, {1, -1}); 727 // y's shape is the result of reducing x along axes 1 and -1 (= 3) 728 TensorShape y_shape({2, 5}); 729 RunTest({x}, {x_shape}, {y}, {y_shape}); 730} 731 732TEST_F(NaryGradTest, Min) { 733 TensorShape x_shape({2, 3}); 734 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 735 auto y = Min(scope_, x, {-1}); 736 // y's shape is the result of reducing x along axes -1 (= 1) 737 TensorShape y_shape({2}); 738 Tensor x_init_value = 739 test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); 740 RunTest(x, x_init_value, y, y_shape); 741} 742 743TEST_F(NaryGradTest, Max) { 744 TensorShape x_shape({2, 3}); 745 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 746 auto y = Max(scope_, x, {-1}); 747 // y's shape is the result of reducing x along axes -1 (= 1) 748 TensorShape y_shape({2}); 749 Tensor x_init_value = 750 test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); 751 RunTest(x, x_init_value, y, y_shape); 752} 753 754TEST_F(NaryGradTest, MinMulti) { 755 // Test gradient when there are multiple minima. 756 // Note that we cannot directly use a test Tensor with multiple 757 // minima, as the numeric estimator will calculate incorrect 758 // gradients when perturbing each entry in the Tensor (which then 759 // changes how many minima exist.) 760 // Instead, we use a single input that broadcast-multiplies a larger 761 // tensor with equal values, and apply reduce_min to the multiplied 762 // result. 763 TensorShape x_shape({1}); 764 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 765 auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); 766 auto y = Min(scope_, all_same, {0}); 767 // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped 768 TensorShape y_shape({1}); 769 RunTest({x}, {x_shape}, {y}, {y_shape}); 770} 771 772TEST_F(NaryGradTest, MaxMulti) { 773 TensorShape x_shape({1}); 774 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 775 auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); 776 auto y = Max(scope_, all_same, {0}); 777 TensorShape y_shape({1}); 778 RunTest({x}, {x_shape}, {y}, {y_shape}); 779} 780 781TEST_F(NaryGradTest, AddN) { 782 TensorShape shape({3, 2, 5}); 783 std::vector<Output> xs; 784 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 785 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 786 xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape))); 787 auto y = AddN(scope_, xs); 788 RunTest(xs, {shape, shape, shape}, {y}, {shape}); 789} 790 791TEST_F(NaryGradTest, Add) { 792 TensorShape x1_shape({3, 2, 5}); 793 TensorShape x2_shape({2, 5}); 794 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 795 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 796 auto y = Add(scope_, x1, x2); 797 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 798} 799 800TEST_F(NaryGradTest, Sub) { 801 TensorShape x1_shape({3, 2, 5}); 802 TensorShape x2_shape({2, 5}); 803 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 804 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 805 auto y = Sub(scope_, x1, x2); 806 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 807} 808 809TEST_F(NaryGradTest, Mul) { 810 TensorShape x1_shape({3, 2, 5}); 811 TensorShape x2_shape({2, 5}); 812 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 813 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 814 auto y = Mul(scope_, x1, x2); 815 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 816} 817 818TEST_F(NaryGradTest, Div) { 819 TensorShape x_shape({3, 2, 5}); 820 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 821 // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large 822 // division errors in the numeric estimator used by the gradient checker. 823 auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x))); 824 RunTest({x}, {x_shape}, {y}, {x_shape}); 825} 826 827TEST_F(NaryGradTest, RealDiv) { 828 TensorShape x_shape({3, 2, 5}); 829 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); 830 // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large 831 // division errors in the numeric estimator used by the gradient checker. 832 auto y = 833 RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x))); 834 RunTest({x}, {x_shape}, {y}, {x_shape}); 835} 836 837TEST_F(NaryGradTest, SquaredDifference) { 838 TensorShape x1_shape({3, 2, 5}); 839 TensorShape x2_shape({2, 5}); 840 auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape)); 841 auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape)); 842 auto y = SquaredDifference(scope_, x1, x2); 843 RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); 844} 845 846TEST_F(NaryGradTest, Maximum) { 847 TensorShape shape({3, 2}); 848 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 849 auto y = Maximum(scope_, x, Const(scope_, 1.0f)); 850 // Select values away from 1.0f to avoid instability when computing 851 // finite differences. 852 Tensor x_init_value = 853 test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2}); 854 RunTest(x, x_init_value, y, shape); 855} 856 857TEST_F(NaryGradTest, Minimum) { 858 TensorShape shape({3, 2}); 859 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 860 auto y = Minimum(scope_, x, Const(scope_, 1.0f)); 861 // Select values away from 1.0f to avoid instability when computing 862 // finite differences. 863 Tensor x_init_value = 864 test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2}); 865 RunTest(x, x_init_value, y, shape); 866} 867 868} // namespace 869} // namespace tensorflow 870