math_grad_test.cc revision 2a792a35111dbd55757fd592a9913b5048b55468
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/framework/gradient_checker.h" 18#include "tensorflow/cc/framework/testutil.h" 19#include "tensorflow/cc/gradients/grad_testutil.h" 20#include "tensorflow/cc/ops/standard_ops.h" 21#include "tensorflow/core/framework/tensor_testutil.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23#include "tensorflow/core/lib/random/random.h" 24 25namespace tensorflow { 26using namespace ops; // NOLINT(build/namespaces) 27 28namespace { 29 30// TODO(andydavis) Test gradient function against numeric gradients output. 31// TODO(andydavis) As more gradients are added move common test functions 32// to a testutil library. 
// Fixture for testing registered gradient functions of coefficient-wise
// (element-wise) unary ops. Each test supplies three lambdas:
//   x_fn(i)      -> value to place at flat index i of the input tensor x
//   dy_fn(x)     -> upstream gradient to feed for that x value
//   dx_fn(x, dy) -> analytically expected input gradient, i.e. dy * f'(x)
// and TestCWiseGrad compares the op's registered gradient against dx_fn.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  // All ops are pinned to CPU so the tests run without any accelerator.
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Selects which unary op TestCWiseGrad builds in its switch below.
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

  // Builds y = op(x) for a fixed {2, 3, 2} tensor, invokes the op's
  // registered gradient function directly (via test::CallGradFunction) with
  // the dy tensor, and checks the produced input gradient against the
  // analytically computed dx tensor. T is the element type (float or
  // complex64 in the tests below).
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    // Input tensor x, filled element-by-element from x_fn.
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    // Upstream gradient dy, derived from the corresponding x value.
    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    // Expected input gradient dx = dx_fn(x, dy), computed element-wise.
    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        // INV maps to the Reciprocal op (1/x).
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    // Call the registered gradient function for y's op with dy as the
    // incoming gradient, then compare its output to the expected dx.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

  // Returns a random element of 'v' (uniform over the list).
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex-valued variant of RV.
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex conjugate. TF's gradients for complex-valued ops use
  // conj(f'(x)), so the expected-dx lambdas below wrap derivatives in this.
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  // Complex constant 1 + 0i, used in the expected-gradient formulas.
  const complex64 one_{1.0, 0};

  Scope scope_;
};

// d|x|/dx = sign(x); since x is drawn from {-1, 0, 1}, sign(x) == x and the
// expected gradient is simply x * dy.
TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
}

// d(-x)/dx = -1, so dx = -dy.
TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return -dy; };
  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
}

// d(1/x)/dx = -1/x^2. x values exclude 0 to keep the gradient finite.
TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return -(1 / (x * x)) * dy;
  };
  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
}

// Complex reciprocal: dx = -conj(1/x^2) * dy.
TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64 x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64 x, const complex64 dy) {
    return -conjugate(one_ / (x * x)) * dy;
  };
  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
}

// d(x^2)/dx = 2x.
TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
}

// Complex square: dx = conj(2x) * dy.
TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(2, 0) * x) * dy;
  };
  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
}

// d(sqrt(x))/dx = 0.5 / sqrt(x). NOTE(review): x can be 0 here, which makes
// both the expected and the computed gradient infinite — ExpectClose is
// relied on to treat matching infinities as equal.
TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * 0.5 * (1.0 / std::sqrt(x));
  };
  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
}

// Complex sqrt: dx = conj(0.5 / sqrt(x)) * dy.
TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
  };
  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
}

// d(rsqrt(x))/dx = -0.5 * x^(-3/2). x values are strictly positive.
TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
  };
  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
}

// Complex rsqrt: dx = conj(-0.5 / sqrt(x) / x) * dy.
TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
  };
  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
}

// d(exp(x))/dx = exp(x).
TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
}

// d(expm1(x))/dx = exp(x), same derivative as Exp.
TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
}

// d(log(x))/dx = 1/x. Negative x values exercise the kernel's NaN path
// consistently on both sides of the comparison.
TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(one_ / x);
  };
  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
}

// d(log1p(x))/dx = 1 / (1 + x).
TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / (1.0 + x));
  };
  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + conjugate(x));
  };
  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}

// d(sinh(x))/dx = cosh(x).
TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cosh(x);
  };
  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cosh(x));
  };
  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
}

// d(cosh(x))/dx = sinh(x).
TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::sinh(x);
  };
  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::sinh(x));
  };
  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
}

// d(tanh(x))/dx = 1 - tanh(x)^2.
TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = std::tanh(x);
    return dy * (1.0 - y * y);
  };
  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = std::tanh(x);
    return dy * conjugate((one_ - y * y));
  };
  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
}

// d(asinh(x))/dx = 1 / cosh(asinh(x)).
TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    auto y = std::asinh(x);
    return dy / std::cosh(y);
  };
  TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    auto y = std::asinh(x);
    return dy / conjugate(std::cosh(y));
  };
  TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn);
}

// d(acosh(x))/dx = 1 / sinh(acosh(x)). Real acosh needs x >= 1.
TEST_F(CWiseUnaryGradTest, Acosh) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) {
    return x + RV({8, 9, 10, 11, 12, 13, 14});
  };
  auto dx_fn = [this](const float x, const float dy) {
    auto y = std::acosh(x);
    return dy / std::sinh(y);
  };
  TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{2, 2}, {3, 3}, {1, 4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    auto y = std::acosh(x);
    return dy / conjugate(std::sinh(y));
  };
  TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn);
}

// d(atanh(x))/dx = 1 / (1 - x^2). x values stay inside (-1, 1).
TEST_F(CWiseUnaryGradTest, Atanh) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1. / (1. - x * x));
  };
  TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(one_ - x * x);
  };
  TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn);
}

// d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)).
TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = 1.0 / (1.0 + std::exp(-x));
    return dy * y * (1.0 - y);
  };
  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = one_ / (one_ + std::exp(-x));
    return dy * conjugate(y * (one_ - y));
  };
  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
}

// Sign is piecewise constant, so its gradient is 0 everywhere.
TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
}

// d(sin(x))/dx = cos(x).
TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cos(x);
  };
  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cos(x));
  };
  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
}

// d(cos(x))/dx = -sin(x).
TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -1.0 * std::sin(x);
  };
  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(-std::sin(x));
  };
  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
}

// d(asin(x))/dx = 1 / sqrt(1 - x^2). At x = +/-1 both sides are infinite.
TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
  }
}

// d(acos(x))/dx = -1 / sqrt(1 - x^2).
TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (-1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / -conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
  }
}

// d(tan(x))/dx = sec(x)^2 = 1 / cos(x)^2.
TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float cosx = std::cos(x);
    return dy * (1 / (cosx * cosx));
  };
  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 cosx = std::cos(x);
    return dy / conjugate(cosx * cosx);
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
  }
}

// d(atan(x))/dx = 1 / (1 + x^2).
TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1 / (1 + x * x));
  };
  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + x * x);
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
  }
}

// Fixture for gradients of the complex<->real conversion ops (Real, Imag,
// Conj), where expected gradients are supplied as explicit tensors rather
// than per-element lambdas.
class CWiseUnaryComplexGradTest : public ::testing::Test {
 protected:
  CWiseUnaryComplexGradTest()
      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType { REAL, IMAG, CONJ };

  // Builds y = op(x), runs the registered gradient function with 'dy', and
  // compares the produced gradient tensor to 'dx_expected'.
  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
                            const Tensor& dy, const Tensor& dx_expected) {
    Output y;
    switch (op_type) {
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor dx;
    test::GetTensor(scope_, grad_outputs[0], &dx);
    test::ExpectClose(dx, dx_expected);
  }

  Scope scope_;
};

// Real(x) gradient: the real dy flows back as dy + 0i.
TEST_F(CWiseUnaryComplexGradTest, Real) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
  TestCWiseGradComplex(REAL, x, dy, dx_expected);
}

// Imag(x) gradient: the real dy flows back as 0 + dy*i.
TEST_F(CWiseUnaryComplexGradTest, Imag) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}

// Conj(x) gradient: conjugation of dy.
TEST_F(CWiseUnaryComplexGradTest, Conj) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
}

// Fixture for MatMul / BatchMatMul gradient tests. The expected gradients
// are computed symbolically — e.g. for z = x*y (no transpose):
// dx = dz * y^T and dy = x^T * dz — and evaluated with forward MatMul ops.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs one gradient check for (Batch)MatMul with the given transpose
  // flags, comparing the registered gradient outputs to the closed-form
  // expressions for each of the four transpose combinations.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    if (!t_x && !t_y) {
      // z = x*y:      dx = dz*y^T,  dy = x^T*dz.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      // z = x^T*y:    dx = y*dz^T,  dy = x*dz.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      // z = x*y^T:    dx = dz*y,    dy = dz^T*x.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      // z = x^T*y^T:  dx = y^T*dz^T, dy = dz^T*x^T.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Builds the forward z = MatMul(x, y), invokes the registered MatMul
  // gradient with 'dz', and returns the evaluated dx/dy tensors in 'out'.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      // BatchMatMul expresses transposition via its adjoint attributes.
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (with optional transposes) and returns the
  // resulting tensor; used to materialize the expected gradients.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Fills 'data' with random x, y, dz tensors whose shapes are consistent
  // with z = MatMul(x, y) for the given batch/transpose configuration.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with random values drawn from Rand().
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random integer in [1, 10]; used for both dimensions and tensor values.
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}

// Fixture for n-ary op gradients, checked numerically: the gradient
// checker compares the symbolic gradient against finite differences.
class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs the numeric gradient checker on ys = f(xs) and asserts the
  // maximum elementwise error stays below 1e-3.
  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

// The mismatched x1/x2 shapes below also exercise broadcasting in the
// binary-op gradients.
TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

}  // namespace
}  // namespace tensorflow