// math_grad_test.cc revision 1d0b8c007b8bc7f77dd63c74f02d87185071f038
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <functional>
#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
32 33class CWiseUnaryGradTest : public ::testing::Test { 34 protected: 35 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 36 37 enum UnaryOpType { 38 ABS, 39 NEG, 40 INV, 41 SQUARE, 42 SQRT, 43 RSQRT, 44 EXP, 45 EXPM1, 46 LOG, 47 LOG1P, 48 TANH, 49 SIGMOID, 50 SIGN, 51 SIN, 52 COS, 53 ASIN, 54 ACOS, 55 TAN, 56 ATAN 57 }; 58 59 void TestCWiseGrad(UnaryOpType op_type, const std::function<float(int)>& x_fn, 60 const std::function<float(float)>& dy_fn, 61 const std::function<float(float, float)>& dx_fn) { 62 Tensor x(DT_FLOAT, {2, 3, 2}); 63 auto x_flat = x.flat<float>(); 64 for (int i = 0; i < x_flat.size(); ++i) { 65 x_flat(i) = x_fn(i); 66 } 67 68 Tensor dy(DT_FLOAT, {2, 3, 2}); 69 auto dy_flat = dy.flat<float>(); 70 for (int i = 0; i < dy_flat.size(); ++i) { 71 dy_flat(i) = dy_fn(x_flat(i)); 72 } 73 74 Tensor dx(DT_FLOAT, {2, 3, 2}); 75 auto dx_flat = dx.flat<float>(); 76 for (int i = 0; i < dx_flat.size(); ++i) { 77 dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); 78 } 79 80 Output y; 81 switch (op_type) { 82 case ABS: 83 y = Abs(scope_, x); 84 break; 85 case NEG: 86 y = Neg(scope_, x); 87 break; 88 case INV: 89 y = Reciprocal(scope_, x); 90 break; 91 case SQUARE: 92 y = Square(scope_, x); 93 break; 94 case SQRT: 95 y = Sqrt(scope_, x); 96 break; 97 case RSQRT: 98 y = Rsqrt(scope_, x); 99 break; 100 case EXP: 101 y = Exp(scope_, x); 102 break; 103 case EXPM1: 104 y = Expm1(scope_, x); 105 break; 106 case LOG: 107 y = Log(scope_, x); 108 break; 109 case LOG1P: 110 y = Log1p(scope_, x); 111 break; 112 case TANH: 113 y = Tanh(scope_, x); 114 break; 115 case SIGMOID: 116 y = Sigmoid(scope_, x); 117 break; 118 case SIGN: 119 y = Sign(scope_, x); 120 break; 121 case SIN: 122 y = Sin(scope_, x); 123 break; 124 case COS: 125 y = Cos(scope_, x); 126 break; 127 case ASIN: 128 y = Asin(scope_, x); 129 break; 130 case ACOS: 131 y = Acos(scope_, x); 132 break; 133 case TAN: 134 y = Tan(scope_, x); 135 break; 136 case ATAN: 137 y = Atan(scope_, x); 138 
break; 139 } 140 141 std::vector<Output> grad_outputs; 142 TF_ASSERT_OK(test::CallGradFunction( 143 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 144 Tensor output; 145 test::GetTensor(scope_, grad_outputs[0], &output); 146 test::ExpectClose(output, dx); 147 } 148 149 float RV(std::vector<float> v) { return v[random::New64() % v.size()]; } 150 151 Scope scope_; 152}; 153 154TEST_F(CWiseUnaryGradTest, Abs) { 155 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 156 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 157 auto dx_fn = [this](const float x, const float dy) { return x * dy; }; 158 TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); 159} 160 161TEST_F(CWiseUnaryGradTest, Neg) { 162 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 163 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 164 auto dx_fn = [this](const float x, const float dy) { return -dy; }; 165 TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); 166} 167 168TEST_F(CWiseUnaryGradTest, Reciprocal) { 169 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 170 auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); }; 171 auto dx_fn = [this](const float x, const float dy) { 172 return -(1 / (x * x)) * dy; 173 }; 174 TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); 175} 176 177TEST_F(CWiseUnaryGradTest, Square) { 178 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 179 auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; 180 auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; 181 TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); 182} 183 184TEST_F(CWiseUnaryGradTest, Sqrt) { 185 auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); }; 186 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; 187 auto dx_fn = [this](const float x, const float dy) { 188 return dy * 0.5 * (1.0 / 
std::sqrt(x)); 189 }; 190 TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); 191} 192 193TEST_F(CWiseUnaryGradTest, Rsqrt) { 194 auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; 195 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; 196 auto dx_fn = [this](const float x, const float dy) { 197 return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); 198 }; 199 TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); 200} 201 202TEST_F(CWiseUnaryGradTest, Exp) { 203 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 204 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 205 auto dx_fn = [this](const float x, const float dy) { 206 return dy * std::exp(x); 207 }; 208 TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); 209} 210 211TEST_F(CWiseUnaryGradTest, Expm1) { 212 auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); }; 213 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 214 auto dx_fn = [this](const float x, const float dy) { 215 return dy * std::exp(x); 216 }; 217 TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); 218} 219 220TEST_F(CWiseUnaryGradTest, Log) { 221 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 222 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 223 auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; 224 TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); 225} 226 227TEST_F(CWiseUnaryGradTest, Log1p) { 228 auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); }; 229 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 230 auto dx_fn = [this](const float x, const float dy) { 231 return dy * (1.0 / (1.0 + x)); 232 }; 233 TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); 234} 235 236TEST_F(CWiseUnaryGradTest, Tanh) { 237 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 238 auto dy_fn = [this](const float x) { return x 
+ RV({-2, 2, -3, 3, -4, 4}); }; 239 auto dx_fn = [this](const float x, const float dy) { 240 const float y = std::tanh(x); 241 return dy * (1.0 - y * y); 242 }; 243 TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); 244} 245 246TEST_F(CWiseUnaryGradTest, Sigmoid) { 247 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 248 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 249 auto dx_fn = [this](const float x, const float dy) { 250 const float y = 1.0 / (1.0 + std::exp(-x)); 251 return dy * y * (1.0 - y); 252 }; 253 TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); 254} 255 256TEST_F(CWiseUnaryGradTest, Sign) { 257 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 258 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 259 auto dx_fn = [this](const float x, const float dy) { return 0.0; }; 260 TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); 261} 262 263TEST_F(CWiseUnaryGradTest, Sin) { 264 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 265 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 266 auto dx_fn = [this](const float x, const float dy) { 267 return dy * std::cos(x); 268 }; 269 TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); 270} 271 272TEST_F(CWiseUnaryGradTest, Cos) { 273 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 274 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 275 auto dx_fn = [this](const float x, const float dy) { 276 return dy * -1.0 * std::sin(x); 277 }; 278 TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); 279} 280 281TEST_F(CWiseUnaryGradTest, Asin) { 282 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 283 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 284 auto dx_fn = [this](const float x, const float dy) { 285 return dy * (1.0 / std::sqrt(1.0 - x * x)); 286 }; 287 TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); 288} 289 
290TEST_F(CWiseUnaryGradTest, Acos) { 291 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 292 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 293 auto dx_fn = [this](const float x, const float dy) { 294 return dy * (-1.0 / std::sqrt(1.0 - x * x)); 295 }; 296 TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); 297} 298 299TEST_F(CWiseUnaryGradTest, Tan) { 300 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 301 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 302 auto dx_fn = [this](const float x, const float dy) { 303 const float cosx = std::cos(x); 304 return dy * (1 / (cosx * cosx)); 305 }; 306 TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); 307} 308 309TEST_F(CWiseUnaryGradTest, Atan) { 310 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 311 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 312 auto dx_fn = [this](const float x, const float dy) { 313 return dy * (1 / (1 + x * x)); 314 }; 315 TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); 316} 317 318class CWiseUnaryComplexGradTest : public ::testing::Test { 319 protected: 320 CWiseUnaryComplexGradTest() 321 : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 322 323 enum UnaryOpType { REAL, IMAG, CONJ }; 324 325 void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x, 326 const Tensor& dy, const Tensor& dx_expected) { 327 Output y; 328 switch (op_type) { 329 case REAL: 330 y = Real(scope_, x); 331 break; 332 case IMAG: 333 y = Imag(scope_, x); 334 break; 335 case CONJ: 336 y = Conj(scope_, x); 337 break; 338 } 339 340 std::vector<Output> grad_outputs; 341 TF_ASSERT_OK(test::CallGradFunction( 342 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 343 Tensor dx; 344 test::GetTensor(scope_, grad_outputs[0], &dx); 345 test::ExpectClose(dx, dx_expected); 346 } 347 348 Scope scope_; 349}; 350 351TEST_F(CWiseUnaryComplexGradTest, Real) { 352 Tensor x = 
test::AsTensor<complex64>( 353 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 354 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 355 Tensor dx_expected = test::AsTensor<complex64>( 356 {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3}); 357 TestCWiseGradComplex(REAL, x, dy, dx_expected); 358} 359 360TEST_F(CWiseUnaryComplexGradTest, Imag) { 361 Tensor x = test::AsTensor<complex64>( 362 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 363 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 364 Tensor dx_expected = test::AsTensor<complex64>( 365 {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3}); 366 TestCWiseGradComplex(IMAG, x, dy, dx_expected); 367} 368 369TEST_F(CWiseUnaryComplexGradTest, Conj) { 370 Tensor x = test::AsTensor<complex64>( 371 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 372 Tensor dy = test::AsTensor<complex64>( 373 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 374 Tensor dx_expected = test::AsTensor<complex64>( 375 {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3}); 376 TestCWiseGradComplex(CONJ, x, dy, dx_expected); 377} 378 379class MathGradTest : public ::testing::Test { 380 protected: 381 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 382 383 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { 384 // Generate random test data. 
385 std::vector<Tensor> data; 386 RandMatMulGradData(is_batch, t_x, t_y, &data); 387 auto x = Const(root_, data[0]); 388 auto y = Const(root_, data[1]); 389 auto dz = Const(root_, data[2]); 390 391 std::vector<Tensor> grad_outputs; 392 ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs); 393 394 if (!t_x && !t_y) { 395 test::ExpectClose(grad_outputs[0], 396 ComputeMatMul(is_batch, dz, false, y, true)); 397 test::ExpectClose(grad_outputs[1], 398 ComputeMatMul(is_batch, x, true, dz, false)); 399 } else if (t_x && !t_y) { 400 test::ExpectClose(grad_outputs[0], 401 ComputeMatMul(is_batch, y, false, dz, true)); 402 test::ExpectClose(grad_outputs[1], 403 ComputeMatMul(is_batch, x, false, dz, false)); 404 } else if (!t_x && t_y) { 405 test::ExpectClose(grad_outputs[0], 406 ComputeMatMul(is_batch, dz, false, y, false)); 407 test::ExpectClose(grad_outputs[1], 408 ComputeMatMul(is_batch, dz, true, x, false)); 409 } else { 410 test::ExpectClose(grad_outputs[0], 411 ComputeMatMul(is_batch, y, true, dz, true)); 412 test::ExpectClose(grad_outputs[1], 413 ComputeMatMul(is_batch, dz, true, x, true)); 414 } 415 } 416 417 void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x, 418 const Output& y, const bool t_y, const Output& dz, 419 std::vector<Tensor>* out) { 420 // Compute forward MatMul: z = MatMul(x, y). 421 Output z; 422 if (is_batch) { 423 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 424 } else { 425 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 426 } 427 TF_ASSERT_OK(root_.status()); 428 CHECK_NOTNULL(z.node()); 429 std::vector<Output> grad_outputs; 430 // Call MatMulGrad which populates 'grad_outputs'. 431 TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz}, 432 &grad_outputs)); 433 ASSERT_EQ(2, grad_outputs.size()); 434 // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. 
435 test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); 436 } 437 438 Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x, 439 const Output& y, const bool t_y) { 440 Output z; 441 if (is_batch) { 442 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 443 } else { 444 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 445 } 446 TF_EXPECT_OK(root_.status()); 447 Tensor out; 448 test::GetTensor(root_, z, &out); 449 return out; 450 } 451 452 void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty, 453 std::vector<Tensor>* data) { 454 // Choose a random batch size in [1, 4] 455 const int b = 1 + (random::New64() % 4); 456 // z = MatMul(x, y) 457 const int m = Rand(); 458 const int k = Rand(); 459 const int n = Rand(); 460 461 TensorShape x_shape; 462 if (is_batch) { 463 // x.shape = [b, m, k] 464 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); 465 } else { 466 // x.shape = [m, k] 467 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 468 } 469 data->emplace_back(DT_FLOAT, x_shape); 470 RandTensor(&data->back()); 471 472 TensorShape y_shape; 473 if (is_batch) { 474 // y.shape = [b, k, n] 475 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); 476 } else { 477 // y.shape = [k, n] 478 y_shape = ty ? 
TensorShape({n, k}) : TensorShape({k, n}); 479 } 480 data->emplace_back(DT_FLOAT, y_shape); 481 RandTensor(&data->back()); 482 483 TensorShape z_shape; 484 if (is_batch) { 485 // z.shape = [b, m, n] 486 z_shape = TensorShape({b, m, n}); 487 } else { 488 // z.shape = [m, n] 489 z_shape = TensorShape({m, n}); 490 } 491 data->emplace_back(DT_FLOAT, z_shape); 492 RandTensor(&data->back()); 493 } 494 495 void RandTensor(Tensor* t) { 496 test::FillFn<float>( 497 t, [this](const int i) { return static_cast<float>(Rand()); }); 498 } 499 500 int Rand() { return 1 + (random::New64() % 10); } 501 502 Scope root_; 503}; 504 505TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 506 TestMatMulGrad(false, false, false); 507} 508 509TEST_F(MathGradTest, MatMulGrad_TransposeX) { 510 TestMatMulGrad(false, true, false); 511} 512 513TEST_F(MathGradTest, MatMulGrad_TransposeY) { 514 TestMatMulGrad(false, false, true); 515} 516 517TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 518 TestMatMulGrad(false, true, true); 519} 520 521TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { 522 TestMatMulGrad(true, false, false); 523} 524 525TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { 526 TestMatMulGrad(true, true, false); 527} 528 529TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { 530 TestMatMulGrad(true, false, true); 531} 532 533TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { 534 TestMatMulGrad(true, true, true); 535} 536 537} // namespace 538} // namespace tensorflow 539