math_grad_test.cc revision 5a1d6d9dac79b46f055462ee52125753524d9f6e
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/framework/gradient_checker.h" 18#include "tensorflow/cc/framework/testutil.h" 19#include "tensorflow/cc/gradients/grad_testutil.h" 20#include "tensorflow/cc/ops/standard_ops.h" 21#include "tensorflow/core/framework/tensor_testutil.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23#include "tensorflow/core/lib/random/random.h" 24 25namespace tensorflow { 26using namespace ops; // NOLINT(build/namespaces) 27 28namespace { 29 30// TODO(andydavis) Test gradient function against numeric gradients output. 31// TODO(andydavis) As more gradients are added move common test functions 32// to a testutil library. 
// Fixture for testing registered C++ gradient functions of coefficient-wise
// (element-wise) unary ops. Each test supplies generator lambdas for the
// input, the upstream gradient, and the analytically expected input gradient,
// and TestCWiseGrad checks the registered gradient against that expectation.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // One enumerator per unary op under test; selects which op node
  // TestCWiseGrad builds.
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

  // Builds y = op(x) for a fixed {2, 3, 2} input tensor, invokes the op's
  // registered gradient function with upstream gradient 'dy', and checks the
  // produced input gradient element-wise against dx_fn.
  //   x_fn:  generates the i-th input element.
  //   dy_fn: generates the upstream gradient element from the input element.
  //   dx_fn: expected input gradient for a given (x, dy) element pair.
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    // Expected input gradient, computed element-wise from (x, dy).
    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    // Call the registered gradient function directly and compare its first
    // output (d/dx) against the hand-computed expectation.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

  // Returns a pseudo-random element of 'v'.
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex-valued counterpart of RV.
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex conjugate helper, used when expressing expected gradients of
  // complex-valued ops (TF gradients conjugate the holomorphic derivative).
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  const complex64 one_{1.0, 0};

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  // With x in {-1, 0, 1}, x * dy equals sign(x) * dy, the Abs gradient.
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return -dy; };
  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
  // d/dx (1/x) = -1/x^2.
  auto dx_fn = [this](const float x, const float dy) {
    return -(1 / (x * x)) * dy;
  };
  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64 x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64 x, const complex64 dy) {
    return -conjugate(one_ / (x * x)) * dy;
  };
  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(2, 0) * x) * dy;
  };
  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  // d/dx sqrt(x) = 0.5 / sqrt(x).
  auto dx_fn = [this](const float x, const float dy) {
    return dy * 0.5 * (1.0 / std::sqrt(x));
  };
  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
  };
  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  // d/dx x^(-1/2) = -0.5 * x^(-3/2).
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
  };
  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
  };
  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx (exp(x) - 1) = exp(x), same as Exp.
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(one_ / x);
  };
  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx log(1 + x) = 1 / (1 + x).
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / (1.0 + x));
  };
  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + conjugate(x));
  };
  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cosh(x);
  };
  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cosh(x));
  };
  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::sinh(x);
  };
  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::sinh(x));
  };
  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx tanh(x) = 1 - tanh(x)^2.
  auto dx_fn = [this](const float x, const float dy) {
    const float y = std::tanh(x);
    return dy * (1.0 - y * y);
  };
  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = std::tanh(x);
    return dy * conjugate((one_ - y * y));
  };
  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx asinh(x) = 1 / cosh(asinh(x)).
  auto dx_fn = [this](const float x, const float dy) {
    auto y = std::asinh(x);
    return dy / std::cosh(y);
  };
  TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    auto y = std::asinh(x);
    return dy / conjugate(std::cosh(y));
  };
  TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh) {
  // acosh requires x >= 1, so the fixture avoids smaller values.
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) {
    return x + RV({8, 9, 10, 11, 12, 13, 14});
  };
  auto dx_fn = [this](const float x, const float dy) {
    auto y = std::acosh(x);
    return dy / std::sinh(y);
  };
  TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{2, 2}, {3, 3}, {1, 4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    auto y = std::acosh(x);
    return dy / conjugate(std::sinh(y));
  };
  TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh) {
  // atanh requires |x| < 1.
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1. / (1. - x * x));
  };
  TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(one_ - x * x);
  };
  TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
  auto dx_fn = [this](const float x, const float dy) {
    const float y = 1.0 / (1.0 + std::exp(-x));
    return dy * y * (1.0 - y);
  };
  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = one_ / (one_ + std::exp(-x));
    return dy * conjugate(y * (one_ - y));
  };
  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  // Sign is piecewise constant, so its gradient is zero everywhere.
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cos(x);
  };
  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cos(x));
  };
  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -1.0 * std::sin(x);
  };
  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(-std::sin(x));
  };
  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  // asin requires |x| <= 1.
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (-1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / -conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  // d/dx tan(x) = sec(x)^2 = 1 / cos(x)^2.
  auto dx_fn = [this](const float x, const float dy) {
    const float cosx = std::cos(x);
    return dy * (1 / (cosx * cosx));
  };
  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 cosx = std::cos(x);
    return dy / conjugate(cosx * cosx);
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1 / (1 + x * x));
  };
  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + x * x);
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
  }
}

// Fixture for unary ops that map between complex and real domains
// (Real, Imag, Angle, Conj). Expected gradients are supplied as whole
// tensors rather than element-wise lambdas.
class CWiseUnaryComplexGradTest : public ::testing::Test {
 protected:
  CWiseUnaryComplexGradTest()
      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType { REAL, IMAG, ANGLE, CONJ };

  // Builds y = op(x), calls the registered gradient with upstream gradient
  // 'dy', and compares the resulting input gradient to 'dx_expected'.
  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
                            const Tensor& dy, const Tensor& dx_expected) {
    Output y;
    switch (op_type) {
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor dx;
    test::GetTensor(scope_, grad_outputs[0], &dx);
    test::ExpectClose(dx, dx_expected);
  }

  Scope scope_;
};

TEST_F(CWiseUnaryComplexGradTest, Real) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
  TestCWiseGradComplex(REAL, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Imag) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Angle) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{5.5, 5.5}, {3, 3},
       {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
       {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
  TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Conj) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
}

// Fixture for MatMul / BatchMatMul gradients: compares the registered
// gradient outputs against the known closed-form MatMul expressions for
// every transpose combination.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs one (is_batch, t_x, t_y) configuration: generates random operands,
  // computes the registered gradient, and checks dx/dy against the
  // closed-form expected products for that transpose combination.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul) with the given transpose flags
  // and returns the resulting tensor; used to build expected gradients.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Fills 'data' with {x, y, dz} tensors of mutually consistent random
  // shapes for z = MatMul(x, y), honoring the transpose flags.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with random values drawn from Rand().
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random integer in [1, 10]; used for both dimensions and tensor values.
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}

// Fixture for n-ary op gradients verified numerically: the gradient checker
// compares the symbolic gradient against finite differences and reports the
// maximum element-wise error.
class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Numeric check over placeholder inputs with random initial values.
  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
    EXPECT_LT(max_error, 1e-3);
  }

  // Numeric check with a caller-chosen initial value for 'x' (used to keep
  // inputs away from non-differentiable points).
  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  // Shapes differ to exercise broadcasting in the gradient.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

}  // namespace
}  // namespace tensorflow