math_grad_test.cc revision 50b999a8336d19400ab75aea66fe46eca2f5fe0b
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
32 33class CWiseUnaryGradTest : public ::testing::Test { 34 protected: 35 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 36 37 enum UnaryOpType { 38 ABS, 39 NEG, 40 INV, 41 SQUARE, 42 SQRT, 43 RSQRT, 44 EXP, 45 EXPM1, 46 LOG, 47 LOG1P, 48 SINH, 49 COSH, 50 TANH, 51 SIGMOID, 52 SIGN, 53 SIN, 54 COS, 55 ASIN, 56 ACOS, 57 TAN, 58 ATAN 59 }; 60 61 template <typename T> 62 void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn, 63 const std::function<T(const T&)>& dy_fn, 64 const std::function<T(const T&, const T&)>& dx_fn) { 65 DataType dtype = DataTypeToEnum<T>::v(); 66 Tensor x(dtype, {2, 3, 2}); 67 auto x_flat = x.flat<T>(); 68 for (int i = 0; i < x_flat.size(); ++i) { 69 x_flat(i) = x_fn(i); 70 } 71 72 Tensor dy(dtype, {2, 3, 2}); 73 auto dy_flat = dy.flat<T>(); 74 for (int i = 0; i < dy_flat.size(); ++i) { 75 dy_flat(i) = dy_fn(x_flat(i)); 76 } 77 78 Tensor dx(dtype, {2, 3, 2}); 79 auto dx_flat = dx.flat<T>(); 80 for (int i = 0; i < dx_flat.size(); ++i) { 81 dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); 82 } 83 84 Output y; 85 switch (op_type) { 86 case ABS: 87 y = Abs(scope_, x); 88 break; 89 case NEG: 90 y = Neg(scope_, x); 91 break; 92 case INV: 93 y = Reciprocal(scope_, x); 94 break; 95 case SQUARE: 96 y = Square(scope_, x); 97 break; 98 case SQRT: 99 y = Sqrt(scope_, x); 100 break; 101 case RSQRT: 102 y = Rsqrt(scope_, x); 103 break; 104 case EXP: 105 y = Exp(scope_, x); 106 break; 107 case EXPM1: 108 y = Expm1(scope_, x); 109 break; 110 case LOG: 111 y = Log(scope_, x); 112 break; 113 case LOG1P: 114 y = Log1p(scope_, x); 115 break; 116 case SINH: 117 y = Sinh(scope_, x); 118 break; 119 case COSH: 120 y = Cosh(scope_, x); 121 break; 122 case TANH: 123 y = Tanh(scope_, x); 124 break; 125 case SIGMOID: 126 y = Sigmoid(scope_, x); 127 break; 128 case SIGN: 129 y = Sign(scope_, x); 130 break; 131 case SIN: 132 y = Sin(scope_, x); 133 break; 134 case COS: 135 y = Cos(scope_, x); 136 break; 137 case ASIN: 138 y = 
Asin(scope_, x); 139 break; 140 case ACOS: 141 y = Acos(scope_, x); 142 break; 143 case TAN: 144 y = Tan(scope_, x); 145 break; 146 case ATAN: 147 y = Atan(scope_, x); 148 break; 149 } 150 151 std::vector<Output> grad_outputs; 152 TF_ASSERT_OK(test::CallGradFunction( 153 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 154 Tensor output; 155 test::GetTensor(scope_, grad_outputs[0], &output); 156 test::ExpectClose(output, dx); 157 } 158 159 float RV(const std::vector<float>& v) { 160 return v[random::New64() % v.size()]; 161 } 162 163 complex64 CRV(const std::vector<complex64>& v) { 164 return v[random::New64() % v.size()]; 165 } 166 167 complex64 conjugate(const complex64& val) { 168 return complex64(val.real(), -val.imag()); 169 } 170 171 const complex64 one_{1.0, 0}; 172 173 Scope scope_; 174}; 175 176TEST_F(CWiseUnaryGradTest, Abs) { 177 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 178 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 179 auto dx_fn = [this](const float x, const float dy) { return x * dy; }; 180 TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn); 181} 182 183TEST_F(CWiseUnaryGradTest, Neg) { 184 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 185 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 186 auto dx_fn = [this](const float x, const float dy) { return -dy; }; 187 TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn); 188} 189 190TEST_F(CWiseUnaryGradTest, Reciprocal) { 191 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 192 auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); }; 193 auto dx_fn = [this](const float x, const float dy) { 194 return -(1 / (x * x)) * dy; 195 }; 196 TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn); 197} 198 199TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { 200 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 201 auto dy_fn = [this](const 
complex64 x) { 202 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 203 }; 204 auto dx_fn = [this](const complex64 x, const complex64 dy) { 205 return -conjugate(one_ / (x * x)) * dy; 206 }; 207 TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn); 208} 209 210TEST_F(CWiseUnaryGradTest, Square) { 211 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 212 auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; 213 auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; 214 TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn); 215} 216 217TEST_F(CWiseUnaryGradTest, Square_Complex) { 218 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 219 auto dy_fn = [this](const complex64& x) { 220 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 221 }; 222 auto dx_fn = [this](const complex64& x, const complex64& dy) { 223 return conjugate(complex64(2, 0) * x) * dy; 224 }; 225 TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn); 226} 227 228TEST_F(CWiseUnaryGradTest, Sqrt) { 229 auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); }; 230 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; 231 auto dx_fn = [this](const float x, const float dy) { 232 return dy * 0.5 * (1.0 / std::sqrt(x)); 233 }; 234 TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn); 235} 236 237TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { 238 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 239 auto dy_fn = [this](const complex64& x) { 240 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 241 }; 242 auto dx_fn = [this](const complex64& x, const complex64& dy) { 243 return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy; 244 }; 245 TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn); 246} 247 248TEST_F(CWiseUnaryGradTest, Rsqrt) { 249 auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; 250 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 
13}); }; 251 auto dx_fn = [this](const float x, const float dy) { 252 return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); 253 }; 254 TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn); 255} 256 257TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { 258 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 259 auto dy_fn = [this](const complex64& x) { 260 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 261 }; 262 auto dx_fn = [this](const complex64& x, const complex64& dy) { 263 return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy; 264 }; 265 TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn); 266} 267 268TEST_F(CWiseUnaryGradTest, Exp) { 269 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 270 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 271 auto dx_fn = [this](const float x, const float dy) { 272 return dy * std::exp(x); 273 }; 274 TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn); 275} 276 277TEST_F(CWiseUnaryGradTest, Exp_Complex) { 278 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 279 auto dy_fn = [this](const complex64& x) { 280 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 281 }; 282 auto dx_fn = [this](const complex64& x, const complex64& dy) { 283 return dy * conjugate(std::exp(x)); 284 }; 285 TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn); 286} 287 288TEST_F(CWiseUnaryGradTest, Expm1) { 289 auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); }; 290 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 291 auto dx_fn = [this](const float x, const float dy) { 292 return dy * std::exp(x); 293 }; 294 TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn); 295} 296 297TEST_F(CWiseUnaryGradTest, Expm1_Complex) { 298 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 299 auto dy_fn = [this](const complex64& x) { 300 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 301 }; 302 auto dx_fn = [this](const 
complex64& x, const complex64& dy) { 303 return dy * conjugate(std::exp(x)); 304 }; 305 TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn); 306} 307 308TEST_F(CWiseUnaryGradTest, Log) { 309 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 310 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 311 auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; 312 TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn); 313} 314 315TEST_F(CWiseUnaryGradTest, Log_Complex) { 316 auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; 317 auto dy_fn = [this](const complex64& x) { 318 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 319 }; 320 auto dx_fn = [this](const complex64& x, const complex64& dy) { 321 return dy * conjugate(one_ / x); 322 }; 323 TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn); 324} 325 326TEST_F(CWiseUnaryGradTest, Log1p) { 327 auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); }; 328 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 329 auto dx_fn = [this](const float x, const float dy) { 330 return dy * (1.0 / (1.0 + x)); 331 }; 332 TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn); 333} 334 335TEST_F(CWiseUnaryGradTest, Log1p_Complex) { 336 auto x_fn = [this](const int i) { 337 return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); 338 }; 339 auto dy_fn = [this](const complex64& x) { 340 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 341 }; 342 auto dx_fn = [this](const complex64& x, const complex64& dy) { 343 return dy / (one_ + conjugate(x)); 344 }; 345 TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn); 346} 347 348TEST_F(CWiseUnaryGradTest, Sinh) { 349 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 350 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 351 auto dx_fn = [this](const float x, const float dy) { 352 return dy * std::cosh(x); 353 }; 354 
TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn); 355} 356 357TEST_F(CWiseUnaryGradTest, Sinh_Complex) { 358 auto x_fn = [this](const int i) { 359 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 360 }; 361 auto dy_fn = [this](const complex64& x) { 362 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 363 }; 364 auto dx_fn = [this](const complex64& x, const complex64& dy) { 365 return dy * conjugate(std::cosh(x)); 366 }; 367 TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn); 368} 369 370TEST_F(CWiseUnaryGradTest, Cosh) { 371 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 372 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 373 auto dx_fn = [this](const float x, const float dy) { 374 return dy * std::sinh(x); 375 }; 376 TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn); 377} 378 379TEST_F(CWiseUnaryGradTest, Cosh_Complex) { 380 auto x_fn = [this](const int i) { 381 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 382 }; 383 auto dy_fn = [this](const complex64& x) { 384 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 385 }; 386 auto dx_fn = [this](const complex64& x, const complex64& dy) { 387 return dy * conjugate(std::sinh(x)); 388 }; 389 TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn); 390} 391 392TEST_F(CWiseUnaryGradTest, Tanh) { 393 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 394 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 395 auto dx_fn = [this](const float x, const float dy) { 396 const float y = std::tanh(x); 397 return dy * (1.0 - y * y); 398 }; 399 TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn); 400} 401 402TEST_F(CWiseUnaryGradTest, Tanh_Complex) { 403 auto x_fn = [this](const int i) { 404 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 405 }; 406 auto dy_fn = [this](const complex64& x) { 407 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 408 }; 409 auto dx_fn = [this](const complex64& x, const complex64& dy) { 410 const 
complex64 y = std::tanh(x); 411 return dy * conjugate((one_ - y * y)); 412 }; 413 TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn); 414} 415 416TEST_F(CWiseUnaryGradTest, Sigmoid) { 417 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 418 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 419 auto dx_fn = [this](const float x, const float dy) { 420 const float y = 1.0 / (1.0 + std::exp(-x)); 421 return dy * y * (1.0 - y); 422 }; 423 TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn); 424} 425 426TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { 427 auto x_fn = [this](const int i) { 428 return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); 429 }; 430 auto dy_fn = [this](const complex64& x) { 431 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 432 }; 433 auto dx_fn = [this](const complex64& x, const complex64& dy) { 434 const complex64 y = one_ / (one_ + std::exp(-x)); 435 return dy * conjugate(y * (one_ - y)); 436 }; 437 TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn); 438} 439 440TEST_F(CWiseUnaryGradTest, Sign) { 441 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 442 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 443 auto dx_fn = [this](const float x, const float dy) { return 0.0; }; 444 TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn); 445} 446 447TEST_F(CWiseUnaryGradTest, Sin) { 448 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 449 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 450 auto dx_fn = [this](const float x, const float dy) { 451 return dy * std::cos(x); 452 }; 453 TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn); 454} 455 456TEST_F(CWiseUnaryGradTest, Sin_Complex) { 457 auto x_fn = [this](const int i) { 458 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 459 }; 460 auto dy_fn = [this](const complex64& x) { 461 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 462 }; 463 auto dx_fn = 
[this](const complex64& x, const complex64& dy) { 464 return dy * conjugate(std::cos(x)); 465 }; 466 TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn); 467} 468 469TEST_F(CWiseUnaryGradTest, Cos) { 470 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 471 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 472 auto dx_fn = [this](const float x, const float dy) { 473 return dy * -1.0 * std::sin(x); 474 }; 475 TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn); 476} 477 478TEST_F(CWiseUnaryGradTest, Cos_Complex) { 479 auto x_fn = [this](const int i) { 480 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 481 }; 482 auto dy_fn = [this](const complex64& x) { 483 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 484 }; 485 auto dx_fn = [this](const complex64& x, const complex64& dy) { 486 return dy * conjugate(-std::sin(x)); 487 }; 488 TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn); 489} 490 491TEST_F(CWiseUnaryGradTest, Asin) { 492 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 493 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 494 auto dx_fn = [this](const float x, const float dy) { 495 return dy * (1.0 / std::sqrt(1.0 - x * x)); 496 }; 497 TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn); 498} 499 500TEST_F(CWiseUnaryGradTest, Asin_Complex) { 501 auto x_fn = [this](const int i) { 502 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 503 }; 504 auto dy_fn = [this](const complex64& x) { 505 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 506 }; 507 auto dx_fn = [this](const complex64& x, const complex64& dy) { 508 return dy / conjugate(std::sqrt(one_ - x * x)); 509 }; 510 // TODO(kbsriram) 511 // Enable test when the asin kernel supports complex numbers 512 if (false) { 513 TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn); 514 } 515} 516 517TEST_F(CWiseUnaryGradTest, Acos) { 518 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 519 auto dy_fn = 
[this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 520 auto dx_fn = [this](const float x, const float dy) { 521 return dy * (-1.0 / std::sqrt(1.0 - x * x)); 522 }; 523 TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn); 524} 525 526TEST_F(CWiseUnaryGradTest, Acos_Complex) { 527 auto x_fn = [this](const int i) { 528 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 529 }; 530 auto dy_fn = [this](const complex64& x) { 531 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 532 }; 533 auto dx_fn = [this](const complex64& x, const complex64& dy) { 534 return dy / -conjugate(std::sqrt(one_ - x * x)); 535 }; 536 // TODO(kbsriram) 537 // Add test when the acos kernel supports complex numbers 538 if (false) { 539 TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn); 540 } 541} 542 543TEST_F(CWiseUnaryGradTest, Tan) { 544 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 545 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 546 auto dx_fn = [this](const float x, const float dy) { 547 const float cosx = std::cos(x); 548 return dy * (1 / (cosx * cosx)); 549 }; 550 TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn); 551} 552 553TEST_F(CWiseUnaryGradTest, Tan_Complex) { 554 auto x_fn = [this](const int i) { 555 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 556 }; 557 auto dy_fn = [this](const complex64& x) { 558 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 559 }; 560 auto dx_fn = [this](const complex64& x, const complex64& dy) { 561 const complex64 cosx = std::cos(x); 562 return dy / conjugate(cosx * cosx); 563 }; 564 // TODO(kbsriram) 565 // Enable when tan kernel supports complex inputs 566 if (false) { 567 TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn); 568 } 569} 570 571TEST_F(CWiseUnaryGradTest, Atan) { 572 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 573 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 574 auto dx_fn = [this](const float x, const 
float dy) { 575 return dy * (1 / (1 + x * x)); 576 }; 577 TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn); 578} 579 580TEST_F(CWiseUnaryGradTest, Atan_Complex) { 581 auto x_fn = [this](const int i) { 582 return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); 583 }; 584 auto dy_fn = [this](const complex64& x) { 585 return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); 586 }; 587 auto dx_fn = [this](const complex64& x, const complex64& dy) { 588 return dy / (one_ + x * x); 589 }; 590 // TODO(kbsriram) 591 // Add test when the atan kernel supports complex numbers 592 if (false) { 593 TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn); 594 } 595} 596 597class CWiseUnaryComplexGradTest : public ::testing::Test { 598 protected: 599 CWiseUnaryComplexGradTest() 600 : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 601 602 enum UnaryOpType { REAL, IMAG, CONJ }; 603 604 void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x, 605 const Tensor& dy, const Tensor& dx_expected) { 606 Output y; 607 switch (op_type) { 608 case REAL: 609 y = Real(scope_, x); 610 break; 611 case IMAG: 612 y = Imag(scope_, x); 613 break; 614 case CONJ: 615 y = Conj(scope_, x); 616 break; 617 } 618 619 std::vector<Output> grad_outputs; 620 TF_ASSERT_OK(test::CallGradFunction( 621 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 622 Tensor dx; 623 test::GetTensor(scope_, grad_outputs[0], &dx); 624 test::ExpectClose(dx, dx_expected); 625 } 626 627 Scope scope_; 628}; 629 630TEST_F(CWiseUnaryComplexGradTest, Real) { 631 Tensor x = test::AsTensor<complex64>( 632 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 633 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 634 Tensor dx_expected = test::AsTensor<complex64>( 635 {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3}); 636 TestCWiseGradComplex(REAL, x, dy, dx_expected); 637} 638 639TEST_F(CWiseUnaryComplexGradTest, Imag) { 640 Tensor x = test::AsTensor<complex64>( 641 {{1, 
-1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 642 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 643 Tensor dx_expected = test::AsTensor<complex64>( 644 {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3}); 645 TestCWiseGradComplex(IMAG, x, dy, dx_expected); 646} 647 648TEST_F(CWiseUnaryComplexGradTest, Conj) { 649 Tensor x = test::AsTensor<complex64>( 650 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 651 Tensor dy = test::AsTensor<complex64>( 652 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 653 Tensor dx_expected = test::AsTensor<complex64>( 654 {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3}); 655 TestCWiseGradComplex(CONJ, x, dy, dx_expected); 656} 657 658class MathGradTest : public ::testing::Test { 659 protected: 660 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 661 662 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { 663 // Generate random test data. 
664 std::vector<Tensor> data; 665 RandMatMulGradData(is_batch, t_x, t_y, &data); 666 auto x = Const(root_, data[0]); 667 auto y = Const(root_, data[1]); 668 auto dz = Const(root_, data[2]); 669 670 std::vector<Tensor> grad_outputs; 671 ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs); 672 673 if (!t_x && !t_y) { 674 test::ExpectClose(grad_outputs[0], 675 ComputeMatMul(is_batch, dz, false, y, true)); 676 test::ExpectClose(grad_outputs[1], 677 ComputeMatMul(is_batch, x, true, dz, false)); 678 } else if (t_x && !t_y) { 679 test::ExpectClose(grad_outputs[0], 680 ComputeMatMul(is_batch, y, false, dz, true)); 681 test::ExpectClose(grad_outputs[1], 682 ComputeMatMul(is_batch, x, false, dz, false)); 683 } else if (!t_x && t_y) { 684 test::ExpectClose(grad_outputs[0], 685 ComputeMatMul(is_batch, dz, false, y, false)); 686 test::ExpectClose(grad_outputs[1], 687 ComputeMatMul(is_batch, dz, true, x, false)); 688 } else { 689 test::ExpectClose(grad_outputs[0], 690 ComputeMatMul(is_batch, y, true, dz, true)); 691 test::ExpectClose(grad_outputs[1], 692 ComputeMatMul(is_batch, dz, true, x, true)); 693 } 694 } 695 696 void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x, 697 const Output& y, const bool t_y, const Output& dz, 698 std::vector<Tensor>* out) { 699 // Compute forward MatMul: z = MatMul(x, y). 700 Output z; 701 if (is_batch) { 702 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 703 } else { 704 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 705 } 706 TF_ASSERT_OK(root_.status()); 707 CHECK_NOTNULL(z.node()); 708 std::vector<Output> grad_outputs; 709 // Call MatMulGrad which populates 'grad_outputs'. 710 TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz}, 711 &grad_outputs)); 712 ASSERT_EQ(2, grad_outputs.size()); 713 // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. 
714 test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); 715 } 716 717 Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x, 718 const Output& y, const bool t_y) { 719 Output z; 720 if (is_batch) { 721 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 722 } else { 723 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 724 } 725 TF_EXPECT_OK(root_.status()); 726 Tensor out; 727 test::GetTensor(root_, z, &out); 728 return out; 729 } 730 731 void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty, 732 std::vector<Tensor>* data) { 733 // Choose a random batch size in [1, 4] 734 const int b = 1 + (random::New64() % 4); 735 // z = MatMul(x, y) 736 const int m = Rand(); 737 const int k = Rand(); 738 const int n = Rand(); 739 740 TensorShape x_shape; 741 if (is_batch) { 742 // x.shape = [b, m, k] 743 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); 744 } else { 745 // x.shape = [m, k] 746 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 747 } 748 data->emplace_back(DT_FLOAT, x_shape); 749 RandTensor(&data->back()); 750 751 TensorShape y_shape; 752 if (is_batch) { 753 // y.shape = [b, k, n] 754 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); 755 } else { 756 // y.shape = [k, n] 757 y_shape = ty ? 
TensorShape({n, k}) : TensorShape({k, n}); 758 } 759 data->emplace_back(DT_FLOAT, y_shape); 760 RandTensor(&data->back()); 761 762 TensorShape z_shape; 763 if (is_batch) { 764 // z.shape = [b, m, n] 765 z_shape = TensorShape({b, m, n}); 766 } else { 767 // z.shape = [m, n] 768 z_shape = TensorShape({m, n}); 769 } 770 data->emplace_back(DT_FLOAT, z_shape); 771 RandTensor(&data->back()); 772 } 773 774 void RandTensor(Tensor* t) { 775 test::FillFn<float>( 776 t, [this](const int i) { return static_cast<float>(Rand()); }); 777 } 778 779 int Rand() { return 1 + (random::New64() % 10); } 780 781 Scope root_; 782}; 783 784TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 785 TestMatMulGrad(false, false, false); 786} 787 788TEST_F(MathGradTest, MatMulGrad_TransposeX) { 789 TestMatMulGrad(false, true, false); 790} 791 792TEST_F(MathGradTest, MatMulGrad_TransposeY) { 793 TestMatMulGrad(false, false, true); 794} 795 796TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 797 TestMatMulGrad(false, true, true); 798} 799 800TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { 801 TestMatMulGrad(true, false, false); 802} 803 804TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { 805 TestMatMulGrad(true, true, false); 806} 807 808TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { 809 TestMatMulGrad(true, false, true); 810} 811 812TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { 813 TestMatMulGrad(true, true, true); 814} 815 816} // namespace 817} // namespace tensorflow 818