math_grad_test.cc revision fb01ebb8c38b2d274f6fe9a7115b2362828a452e
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/framework/testutil.h" 18#include "tensorflow/cc/gradients/grad_testutil.h" 19#include "tensorflow/cc/ops/standard_ops.h" 20#include "tensorflow/core/framework/tensor_testutil.h" 21#include "tensorflow/core/lib/core/status_test_util.h" 22#include "tensorflow/core/lib/random/random.h" 23 24namespace tensorflow { 25using namespace ops; // NOLINT(build/namespaces) 26 27namespace { 28 29// TODO(andydavis) Test gradient function against numeric gradients output. 30// TODO(andydavis) As more gradients are added move common test functions 31// to a testutil library. 32 33class CWiseUnaryGradTest : public ::testing::Test { 34 protected: 35 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 36 37 enum UnaryOpType { 38 ABS, 39 NEG, 40 INV, 41 SQUARE, 42 SQRT, 43 RSQRT, 44 EXP, 45 LOG, 46 TANH, 47 SIGMOID, 48 SIGN, 49 SIN, 50 COS, 51 ASIN, 52 ACOS, 53 TAN, 54 ATAN 55 }; 56 57 void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn, 58 std::function<float(float)> dy_fn, 59 std::function<float(float, float)> dx_fn) { 60 Tensor x(DT_FLOAT, {2, 3, 2}); 61 auto x_flat = x.flat<float>(); 62 for (int i = 0; i < x_flat.size(); ++i) { 63 x_flat(i) = x_fn(i); 64 } 65 66 Tensor dy(DT_FLOAT, {2, 3, 2}); 67 auto dy_flat = dy.flat<float>(); 68 for (int i = 0; i < dy_flat.size(); ++i) { 69 dy_flat(i) = dy_fn(x_flat(i)); 70 } 71 72 Tensor dx(DT_FLOAT, {2, 3, 2}); 73 auto dx_flat = dx.flat<float>(); 74 for (int i = 0; i < dx_flat.size(); ++i) { 75 dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); 76 } 77 78 Output y; 79 switch (op_type) { 80 case ABS: 81 y = Abs(scope_, x); 82 break; 83 case NEG: 84 y = Neg(scope_, x); 85 break; 86 case INV: 87 y = Reciprocal(scope_, x); 88 break; 89 case SQUARE: 90 y = Square(scope_, x); 91 break; 92 case SQRT: 93 y = Sqrt(scope_, x); 94 break; 95 case RSQRT: 96 y = Rsqrt(scope_, x); 97 break; 98 case EXP: 99 y = Exp(scope_, x); 100 break; 101 case LOG: 102 y = Log(scope_, x); 103 break; 104 case TANH: 105 y = Tanh(scope_, x); 106 break; 107 case SIGMOID: 108 y = Sigmoid(scope_, x); 109 break; 110 case SIGN: 111 y = Sign(scope_, x); 112 break; 113 case SIN: 114 y = Sin(scope_, x); 115 break; 116 case COS: 117 y = Cos(scope_, x); 118 break; 119 case ASIN: 120 y = Asin(scope_, x); 121 break; 122 case ACOS: 123 y = Acos(scope_, x); 124 break; 125 case TAN: 126 y = Tan(scope_, x); 127 break; 128 case ATAN: 129 y = Atan(scope_, x); 130 break; 131 } 132 133 std::vector<Output> grad_outputs; 134 TF_ASSERT_OK(test::CallGradFunction( 135 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 136 Tensor output; 137 test::GetTensor(scope_, grad_outputs[0], &output); 138 test::ExpectClose(output, dx); 139 } 140 141 float RV(std::vector<float> v) { return v[random::New64() % v.size()]; } 142 143 Scope scope_; 144}; 145 146TEST_F(CWiseUnaryGradTest, Abs) { 147 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 148 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 149 auto dx_fn = [this](const float x, const float dy) { return x * dy; }; 150 TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); 151} 152 153TEST_F(CWiseUnaryGradTest, Neg) { 154 auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; 155 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 156 auto dx_fn = [this](const float x, const float dy) { return -dy; }; 157 TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); 158} 159 160TEST_F(CWiseUnaryGradTest, Reciprocal) { 161 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 162 auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); }; 163 auto dx_fn = [this](const float x, const float dy) { 164 return -(1 / (x * x)) * dy; 165 }; 166 TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); 167} 168 169TEST_F(CWiseUnaryGradTest, Square) { 170 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 171 auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; 172 auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; 173 TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); 174} 175 176TEST_F(CWiseUnaryGradTest, Sqrt) { 177 auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); }; 178 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; 179 auto dx_fn = [this](const float x, const float dy) { 180 return dy * 0.5 * (1.0 / std::sqrt(x)); 181 }; 182 TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); 183} 184 185TEST_F(CWiseUnaryGradTest, Rsqrt) { 186 auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; 187 auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; 188 auto dx_fn = [this](const float x, const float dy) { 189 return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); 190 }; 191 TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); 192} 193 194TEST_F(CWiseUnaryGradTest, Exp) { 195 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 196 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 197 auto dx_fn = [this](const float x, const float dy) { 198 return dy * std::exp(x); 199 }; 200 TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); 201} 202 203TEST_F(CWiseUnaryGradTest, Log) { 204 auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; 205 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 206 auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; 207 TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); 208} 209 210TEST_F(CWiseUnaryGradTest, Tanh) { 211 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 212 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 213 auto dx_fn = [this](const float x, const float dy) { 214 const float y = std::tanh(x); 215 return dy * (1.0 - y * y); 216 }; 217 TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); 218} 219 220TEST_F(CWiseUnaryGradTest, Sigmoid) { 221 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 222 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 223 auto dx_fn = [this](const float x, const float dy) { 224 const float y = 1.0 / (1.0 + std::exp(-x)); 225 return dy * y * (1.0 - y); 226 }; 227 TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); 228} 229 230TEST_F(CWiseUnaryGradTest, Sign) { 231 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 232 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 233 auto dx_fn = [this](const float x, const float dy) { return 0.0; }; 234 TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); 235} 236 237TEST_F(CWiseUnaryGradTest, Sin) { 238 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 239 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 240 auto dx_fn = [this](const float x, const float dy) { 241 return dy * std::cos(x); 242 }; 243 TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); 244} 245 246TEST_F(CWiseUnaryGradTest, Cos) { 247 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 248 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 249 auto dx_fn = [this](const float x, const float dy) { 250 return dy * -1.0 * std::sin(x); 251 }; 252 TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); 253} 254 255TEST_F(CWiseUnaryGradTest, Asin) { 256 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 257 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 258 auto dx_fn = [this](const float x, const float dy) { 259 return dy * (1.0 / std::sqrt(1.0 - x * x)); 260 }; 261 TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); 262} 263 264TEST_F(CWiseUnaryGradTest, Acos) { 265 auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; 266 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 267 auto dx_fn = [this](const float x, const float dy) { 268 return dy * (-1.0 / std::sqrt(1.0 - x * x)); 269 }; 270 TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); 271} 272 273TEST_F(CWiseUnaryGradTest, Tan) { 274 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 275 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 276 auto dx_fn = [this](const float x, const float dy) { 277 const float cosx = std::cos(x); 278 return dy * (1 / (cosx * cosx)); 279 }; 280 TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); 281} 282 283TEST_F(CWiseUnaryGradTest, Atan) { 284 auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; 285 auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; 286 auto dx_fn = [this](const float x, const float dy) { 287 return dy * (1 / (1 + x * x)); 288 }; 289 TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); 290} 291 292class CWiseUnaryComplexGradTest : public ::testing::Test { 293 protected: 294 CWiseUnaryComplexGradTest() 295 : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 296 297 enum UnaryOpType { REAL, IMAG, CONJ }; 298 299 void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x, 300 const Tensor& dy, const Tensor& dx_expected) { 301 Output y; 302 switch (op_type) { 303 case REAL: 304 y = Real(scope_, x); 305 break; 306 case IMAG: 307 y = Imag(scope_, x); 308 break; 309 case CONJ: 310 y = Conj(scope_, x); 311 break; 312 } 313 314 std::vector<Output> grad_outputs; 315 TF_ASSERT_OK(test::CallGradFunction( 316 scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); 317 Tensor dx; 318 test::GetTensor(scope_, grad_outputs[0], &dx); 319 test::ExpectClose(dx, dx_expected); 320 } 321 322 Scope scope_; 323}; 324 325TEST_F(CWiseUnaryComplexGradTest, Real) { 326 Tensor x = test::AsTensor<complex64>( 327 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 328 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 329 Tensor dx_expected = test::AsTensor<complex64>( 330 {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3}); 331 TestCWiseGradComplex(REAL, x, dy, dx_expected); 332} 333 334TEST_F(CWiseUnaryComplexGradTest, Imag) { 335 Tensor x = test::AsTensor<complex64>( 336 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 337 Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); 338 Tensor dx_expected = test::AsTensor<complex64>( 339 {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3}); 340 TestCWiseGradComplex(IMAG, x, dy, dx_expected); 341} 342 343TEST_F(CWiseUnaryComplexGradTest, Conj) { 344 Tensor x = test::AsTensor<complex64>( 345 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 346 Tensor dy = test::AsTensor<complex64>( 347 {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); 348 Tensor dx_expected = test::AsTensor<complex64>( 349 {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3}); 350 TestCWiseGradComplex(CONJ, x, dy, dx_expected); 351} 352 353class MathGradTest : public ::testing::Test { 354 protected: 355 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} 356 357 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { 358 // Generate random test data. 359 std::vector<Tensor> data; 360 RandMatMulGradData(is_batch, t_x, t_y, &data); 361 auto x = Const(root_, data[0]); 362 auto y = Const(root_, data[1]); 363 auto dz = Const(root_, data[2]); 364 365 std::vector<Tensor> grad_outputs; 366 ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs); 367 368 if (!t_x && !t_y) { 369 test::ExpectClose(grad_outputs[0], 370 ComputeMatMul(is_batch, dz, false, y, true)); 371 test::ExpectClose(grad_outputs[1], 372 ComputeMatMul(is_batch, x, true, dz, false)); 373 } else if (t_x && !t_y) { 374 test::ExpectClose(grad_outputs[0], 375 ComputeMatMul(is_batch, y, false, dz, true)); 376 test::ExpectClose(grad_outputs[1], 377 ComputeMatMul(is_batch, x, false, dz, false)); 378 } else if (!t_x && t_y) { 379 test::ExpectClose(grad_outputs[0], 380 ComputeMatMul(is_batch, dz, false, y, false)); 381 test::ExpectClose(grad_outputs[1], 382 ComputeMatMul(is_batch, dz, true, x, false)); 383 } else { 384 test::ExpectClose(grad_outputs[0], 385 ComputeMatMul(is_batch, y, true, dz, true)); 386 test::ExpectClose(grad_outputs[1], 387 ComputeMatMul(is_batch, dz, true, x, true)); 388 } 389 } 390 391 void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x, 392 const Output& y, const bool t_y, const Output& dz, 393 std::vector<Tensor>* out) { 394 // Compute forward MatMul: z = MatMul(x, y). 395 Output z; 396 if (is_batch) { 397 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 398 } else { 399 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 400 } 401 TF_ASSERT_OK(root_.status()); 402 CHECK_NOTNULL(z.node()); 403 std::vector<Output> grad_outputs; 404 // Call MatMulGrad which populates 'grad_outputs'. 405 TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz}, 406 &grad_outputs)); 407 ASSERT_EQ(2, grad_outputs.size()); 408 // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. 409 test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); 410 } 411 412 Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x, 413 const Output& y, const bool t_y) { 414 Output z; 415 if (is_batch) { 416 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); 417 } else { 418 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 419 } 420 TF_EXPECT_OK(root_.status()); 421 Tensor out; 422 test::GetTensor(root_, z, &out); 423 return out; 424 } 425 426 void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty, 427 std::vector<Tensor>* data) { 428 // Choose a random batch size in [1, 4] 429 const int b = 1 + (random::New64() % 4); 430 // z = MatMul(x, y) 431 const int m = Rand(); 432 const int k = Rand(); 433 const int n = Rand(); 434 435 TensorShape x_shape; 436 if (is_batch) { 437 // x.shape = [b, m, k] 438 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); 439 } else { 440 // x.shape = [m, k] 441 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 442 } 443 data->emplace_back(DT_FLOAT, x_shape); 444 RandTensor(&data->back()); 445 446 TensorShape y_shape; 447 if (is_batch) { 448 // y.shape = [b, k, n] 449 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); 450 } else { 451 // y.shape = [k, n] 452 y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); 453 } 454 data->emplace_back(DT_FLOAT, y_shape); 455 RandTensor(&data->back()); 456 457 TensorShape z_shape; 458 if (is_batch) { 459 // z.shape = [b, m, n] 460 z_shape = TensorShape({b, m, n}); 461 } else { 462 // z.shape = [m, n] 463 z_shape = TensorShape({m, n}); 464 } 465 data->emplace_back(DT_FLOAT, z_shape); 466 RandTensor(&data->back()); 467 } 468 469 void RandTensor(Tensor* t) { 470 test::FillFn<float>( 471 t, [this](const int i) { return static_cast<float>(Rand()); }); 472 } 473 474 int Rand() { return 1 + (random::New64() % 10); } 475 476 Scope root_; 477}; 478 479TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 480 TestMatMulGrad(false, false, false); 481} 482 483TEST_F(MathGradTest, MatMulGrad_TransposeX) { 484 TestMatMulGrad(false, true, false); 485} 486 487TEST_F(MathGradTest, MatMulGrad_TransposeY) { 488 TestMatMulGrad(false, false, true); 489} 490 491TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 492 TestMatMulGrad(false, true, true); 493} 494 495TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { 496 TestMatMulGrad(true, false, false); 497} 498 499TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { 500 TestMatMulGrad(true, true, false); 501} 502 503TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { 504 TestMatMulGrad(true, false, true); 505} 506 507TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { 508 TestMatMulGrad(true, true, true); 509} 510 511} // namespace 512} // namespace tensorflow 513