// math_grad_test.cc — revision 1fa73c53ab95693f070ce70e6be0c644d83c163a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
// Test fixture for the registered gradient functions of coefficient-wise
// unary ops, pinned to CPU via the root scope's device.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Identifies which unary op to build in TestCWiseGrad's switch below.
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    TANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

  // Builds a {2, 3, 2} input tensor element-wise via `x_fn`, an upstream
  // gradient `dy` via `dy_fn(x)`, and the expected input gradient `dx` via
  // `dx_fn(x, dy)`. Then applies the unary op selected by `op_type` to the
  // input, invokes the registered gradient function, and checks that the
  // computed gradient is close to the expected `dx`.
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    // Build the forward op for the requested type.
    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    // Run the registered gradient function and compare against `dx`.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

  // Returns a value picked uniformly at random from `v`.
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Returns a complex value picked uniformly at random from `v`.
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex conjugate of `val`.
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  // Complex constant 1 + 0i, reused by the expected-gradient lambdas.
  const complex64 one_{1.0, 0};

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  // With x restricted to {-1, 0, 1}, x * dy equals sign(x) * dy, the
  // gradient of |x|.
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return -dy; };
  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return -(1 / (x * x)) * dy;
  };
  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64 x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64 x, const complex64 dy) {
    return -conjugate(one_ / (x * x)) * dy;
  };
  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(2, 0) * x) * dy;
  };
  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * 0.5 * (1.0 / std::sqrt(x));
  };
  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
  };
  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
  };
  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
  };
  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  // d/dx expm1(x) == d/dx (exp(x) - 1) == exp(x).
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(one_ / x);
  };
  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / (1.0 + x));
  };
  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + conjugate(x));
  };
  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = std::tanh(x);
    return dy * (1.0 - y * y);
  };
  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = std::tanh(x);
    return dy * conjugate((one_ - y * y));
  };
  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = 1.0 / (1.0 + std::exp(-x));
    return dy * y * (1.0 - y);
  };
  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = one_ / (one_ + std::exp(-x));
    return dy * conjugate(y * (one_ - y));
  };
  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  // sign(x) is piecewise constant, so its gradient is zero everywhere.
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cos(x);
  };
  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cos(x));
  };
  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -1.0 * std::sin(x);
  };
  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(-std::sin(x));
  };
  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  // x kept within [-1, 1], the domain of asin.
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  // x kept within [-1, 1], the domain of acos.
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (-1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / -conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float cosx = std::cos(x);
    return dy * (1 / (cosx * cosx));
  };
  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 cosx = std::cos(x);
    return dy / conjugate(cosx * cosx);
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1 / (1 + x * x));
  };
  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + x * x);
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
  }
}

// Test fixture for gradients of the Real / Imag / Conj ops, which map
// between complex tensors and (for Real/Imag) real-valued gradients.
class CWiseUnaryComplexGradTest : public ::testing::Test {
 protected:
  CWiseUnaryComplexGradTest()
      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType { REAL, IMAG, CONJ };

  // Applies the op selected by `op_type` to `x`, runs its registered
  // gradient function with upstream gradient `dy`, and checks the result
  // against `dx_expected`.
  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
                            const Tensor& dy, const Tensor& dx_expected) {
    Output y;
    switch (op_type) {
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor dx;
    test::GetTensor(scope_, grad_outputs[0], &dx);
    test::ExpectClose(dx, dx_expected);
  }

  Scope scope_;
};

TEST_F(CWiseUnaryComplexGradTest, Real) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
  TestCWiseGradComplex(REAL, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Imag) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Conj) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
}

// Test fixture for MatMul / BatchMatMul gradients, checked against
// manually-constructed matrix products for every transpose combination.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs the registered MatMul (or BatchMatMul) gradient on random data and
  // compares the two gradient outputs against the analytically expected
  // products for the given transpose flags `t_x` / `t_y`.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Builds z = MatMul(x, y), invokes the registered gradient function with
  // upstream gradient `dz`, and returns the evaluated 'dx' and 'dy' tensors
  // in `out`.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul) with the given transpose flags
  // and returns the resulting tensor; used to build expected gradients.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Fills `data` with three random tensors {x, y, dz} whose shapes are
  // consistent with z = MatMul(x, y) under transpose flags `tx` / `ty`.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills `t` with random float values produced by Rand() (each in [1, 10]).
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random integer in [1, 10].
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}

}  // namespace
}  // namespace tensorflow