// math_grad_test.cc revision 1d0b8c007b8bc7f77dd63c74f02d87185071f038
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/testutil.h"
18#include "tensorflow/cc/gradients/grad_testutil.h"
19#include "tensorflow/cc/ops/standard_ops.h"
20#include "tensorflow/core/framework/tensor_testutil.h"
21#include "tensorflow/core/lib/core/status_test_util.h"
22#include "tensorflow/core/lib/random/random.h"
23
24namespace tensorflow {
25using namespace ops;  // NOLINT(build/namespaces)
26
27namespace {
28
29// TODO(andydavis) Test gradient function against numeric gradients output.
30// TODO(andydavis) As more gradients are added move common test functions
31// to a testutil library.
32
33class CWiseUnaryGradTest : public ::testing::Test {
34 protected:
35  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
36
37  enum UnaryOpType {
38    ABS,
39    NEG,
40    INV,
41    SQUARE,
42    SQRT,
43    RSQRT,
44    EXP,
45    EXPM1,
46    LOG,
47    LOG1P,
48    TANH,
49    SIGMOID,
50    SIGN,
51    SIN,
52    COS,
53    ASIN,
54    ACOS,
55    TAN,
56    ATAN
57  };
58
59  void TestCWiseGrad(UnaryOpType op_type, const std::function<float(int)>& x_fn,
60                     const std::function<float(float)>& dy_fn,
61                     const std::function<float(float, float)>& dx_fn) {
62    Tensor x(DT_FLOAT, {2, 3, 2});
63    auto x_flat = x.flat<float>();
64    for (int i = 0; i < x_flat.size(); ++i) {
65      x_flat(i) = x_fn(i);
66    }
67
68    Tensor dy(DT_FLOAT, {2, 3, 2});
69    auto dy_flat = dy.flat<float>();
70    for (int i = 0; i < dy_flat.size(); ++i) {
71      dy_flat(i) = dy_fn(x_flat(i));
72    }
73
74    Tensor dx(DT_FLOAT, {2, 3, 2});
75    auto dx_flat = dx.flat<float>();
76    for (int i = 0; i < dx_flat.size(); ++i) {
77      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
78    }
79
80    Output y;
81    switch (op_type) {
82      case ABS:
83        y = Abs(scope_, x);
84        break;
85      case NEG:
86        y = Neg(scope_, x);
87        break;
88      case INV:
89        y = Reciprocal(scope_, x);
90        break;
91      case SQUARE:
92        y = Square(scope_, x);
93        break;
94      case SQRT:
95        y = Sqrt(scope_, x);
96        break;
97      case RSQRT:
98        y = Rsqrt(scope_, x);
99        break;
100      case EXP:
101        y = Exp(scope_, x);
102        break;
103      case EXPM1:
104        y = Expm1(scope_, x);
105        break;
106      case LOG:
107        y = Log(scope_, x);
108        break;
109      case LOG1P:
110        y = Log1p(scope_, x);
111        break;
112      case TANH:
113        y = Tanh(scope_, x);
114        break;
115      case SIGMOID:
116        y = Sigmoid(scope_, x);
117        break;
118      case SIGN:
119        y = Sign(scope_, x);
120        break;
121      case SIN:
122        y = Sin(scope_, x);
123        break;
124      case COS:
125        y = Cos(scope_, x);
126        break;
127      case ASIN:
128        y = Asin(scope_, x);
129        break;
130      case ACOS:
131        y = Acos(scope_, x);
132        break;
133      case TAN:
134        y = Tan(scope_, x);
135        break;
136      case ATAN:
137        y = Atan(scope_, x);
138        break;
139    }
140
141    std::vector<Output> grad_outputs;
142    TF_ASSERT_OK(test::CallGradFunction(
143        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
144    Tensor output;
145    test::GetTensor(scope_, grad_outputs[0], &output);
146    test::ExpectClose(output, dx);
147  }
148
149  float RV(std::vector<float> v) { return v[random::New64() % v.size()]; }
150
151  Scope scope_;
152};
153
154TEST_F(CWiseUnaryGradTest, Abs) {
155  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
156  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
157  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
158  TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn);
159}
160
161TEST_F(CWiseUnaryGradTest, Neg) {
162  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
163  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
164  auto dx_fn = [this](const float x, const float dy) { return -dy; };
165  TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
166}
167
168TEST_F(CWiseUnaryGradTest, Reciprocal) {
169  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
170  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
171  auto dx_fn = [this](const float x, const float dy) {
172    return -(1 / (x * x)) * dy;
173  };
174  TestCWiseGrad(INV, x_fn, dy_fn, dx_fn);
175}
176
177TEST_F(CWiseUnaryGradTest, Square) {
178  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
179  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
180  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
181  TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn);
182}
183
184TEST_F(CWiseUnaryGradTest, Sqrt) {
185  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
186  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
187  auto dx_fn = [this](const float x, const float dy) {
188    return dy * 0.5 * (1.0 / std::sqrt(x));
189  };
190  TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn);
191}
192
193TEST_F(CWiseUnaryGradTest, Rsqrt) {
194  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
195  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
196  auto dx_fn = [this](const float x, const float dy) {
197    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
198  };
199  TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn);
200}
201
202TEST_F(CWiseUnaryGradTest, Exp) {
203  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
204  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
205  auto dx_fn = [this](const float x, const float dy) {
206    return dy * std::exp(x);
207  };
208  TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn);
209}
210
211TEST_F(CWiseUnaryGradTest, Expm1) {
212  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
213  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
214  auto dx_fn = [this](const float x, const float dy) {
215    return dy * std::exp(x);
216  };
217  TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn);
218}
219
220TEST_F(CWiseUnaryGradTest, Log) {
221  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
222  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
223  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
224  TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
225}
226
227TEST_F(CWiseUnaryGradTest, Log1p) {
228  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
229  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
230  auto dx_fn = [this](const float x, const float dy) {
231    return dy * (1.0 / (1.0 + x));
232  };
233  TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn);
234}
235
236TEST_F(CWiseUnaryGradTest, Tanh) {
237  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
238  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
239  auto dx_fn = [this](const float x, const float dy) {
240    const float y = std::tanh(x);
241    return dy * (1.0 - y * y);
242  };
243  TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn);
244}
245
246TEST_F(CWiseUnaryGradTest, Sigmoid) {
247  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
248  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
249  auto dx_fn = [this](const float x, const float dy) {
250    const float y = 1.0 / (1.0 + std::exp(-x));
251    return dy * y * (1.0 - y);
252  };
253  TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn);
254}
255
256TEST_F(CWiseUnaryGradTest, Sign) {
257  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
258  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
259  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
260  TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn);
261}
262
263TEST_F(CWiseUnaryGradTest, Sin) {
264  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
265  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
266  auto dx_fn = [this](const float x, const float dy) {
267    return dy * std::cos(x);
268  };
269  TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn);
270}
271
272TEST_F(CWiseUnaryGradTest, Cos) {
273  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
274  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
275  auto dx_fn = [this](const float x, const float dy) {
276    return dy * -1.0 * std::sin(x);
277  };
278  TestCWiseGrad(COS, x_fn, dy_fn, dx_fn);
279}
280
281TEST_F(CWiseUnaryGradTest, Asin) {
282  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
283  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
284  auto dx_fn = [this](const float x, const float dy) {
285    return dy * (1.0 / std::sqrt(1.0 - x * x));
286  };
287  TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn);
288}
289
290TEST_F(CWiseUnaryGradTest, Acos) {
291  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
292  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
293  auto dx_fn = [this](const float x, const float dy) {
294    return dy * (-1.0 / std::sqrt(1.0 - x * x));
295  };
296  TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn);
297}
298
299TEST_F(CWiseUnaryGradTest, Tan) {
300  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
301  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
302  auto dx_fn = [this](const float x, const float dy) {
303    const float cosx = std::cos(x);
304    return dy * (1 / (cosx * cosx));
305  };
306  TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn);
307}
308
309TEST_F(CWiseUnaryGradTest, Atan) {
310  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
311  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
312  auto dx_fn = [this](const float x, const float dy) {
313    return dy * (1 / (1 + x * x));
314  };
315  TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn);
316}
317
318class CWiseUnaryComplexGradTest : public ::testing::Test {
319 protected:
320  CWiseUnaryComplexGradTest()
321      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
322
323  enum UnaryOpType { REAL, IMAG, CONJ };
324
325  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
326                            const Tensor& dy, const Tensor& dx_expected) {
327    Output y;
328    switch (op_type) {
329      case REAL:
330        y = Real(scope_, x);
331        break;
332      case IMAG:
333        y = Imag(scope_, x);
334        break;
335      case CONJ:
336        y = Conj(scope_, x);
337        break;
338    }
339
340    std::vector<Output> grad_outputs;
341    TF_ASSERT_OK(test::CallGradFunction(
342        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
343    Tensor dx;
344    test::GetTensor(scope_, grad_outputs[0], &dx);
345    test::ExpectClose(dx, dx_expected);
346  }
347
348  Scope scope_;
349};
350
351TEST_F(CWiseUnaryComplexGradTest, Real) {
352  Tensor x = test::AsTensor<complex64>(
353      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
354  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
355  Tensor dx_expected = test::AsTensor<complex64>(
356      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
357  TestCWiseGradComplex(REAL, x, dy, dx_expected);
358}
359
360TEST_F(CWiseUnaryComplexGradTest, Imag) {
361  Tensor x = test::AsTensor<complex64>(
362      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
363  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
364  Tensor dx_expected = test::AsTensor<complex64>(
365      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
366  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
367}
368
369TEST_F(CWiseUnaryComplexGradTest, Conj) {
370  Tensor x = test::AsTensor<complex64>(
371      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
372  Tensor dy = test::AsTensor<complex64>(
373      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
374  Tensor dx_expected = test::AsTensor<complex64>(
375      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
376  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
377}
378
379class MathGradTest : public ::testing::Test {
380 protected:
381  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
382
383  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
384    // Generate random test data.
385    std::vector<Tensor> data;
386    RandMatMulGradData(is_batch, t_x, t_y, &data);
387    auto x = Const(root_, data[0]);
388    auto y = Const(root_, data[1]);
389    auto dz = Const(root_, data[2]);
390
391    std::vector<Tensor> grad_outputs;
392    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
393
394    if (!t_x && !t_y) {
395      test::ExpectClose(grad_outputs[0],
396                        ComputeMatMul(is_batch, dz, false, y, true));
397      test::ExpectClose(grad_outputs[1],
398                        ComputeMatMul(is_batch, x, true, dz, false));
399    } else if (t_x && !t_y) {
400      test::ExpectClose(grad_outputs[0],
401                        ComputeMatMul(is_batch, y, false, dz, true));
402      test::ExpectClose(grad_outputs[1],
403                        ComputeMatMul(is_batch, x, false, dz, false));
404    } else if (!t_x && t_y) {
405      test::ExpectClose(grad_outputs[0],
406                        ComputeMatMul(is_batch, dz, false, y, false));
407      test::ExpectClose(grad_outputs[1],
408                        ComputeMatMul(is_batch, dz, true, x, false));
409    } else {
410      test::ExpectClose(grad_outputs[0],
411                        ComputeMatMul(is_batch, y, true, dz, true));
412      test::ExpectClose(grad_outputs[1],
413                        ComputeMatMul(is_batch, dz, true, x, true));
414    }
415  }
416
417  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
418                         const Output& y, const bool t_y, const Output& dz,
419                         std::vector<Tensor>* out) {
420    // Compute forward MatMul: z = MatMul(x, y).
421    Output z;
422    if (is_batch) {
423      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
424    } else {
425      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
426    }
427    TF_ASSERT_OK(root_.status());
428    CHECK_NOTNULL(z.node());
429    std::vector<Output> grad_outputs;
430    // Call MatMulGrad which populates 'grad_outputs'.
431    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
432                                        &grad_outputs));
433    ASSERT_EQ(2, grad_outputs.size());
434    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
435    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
436  }
437
438  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
439                       const Output& y, const bool t_y) {
440    Output z;
441    if (is_batch) {
442      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
443    } else {
444      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
445    }
446    TF_EXPECT_OK(root_.status());
447    Tensor out;
448    test::GetTensor(root_, z, &out);
449    return out;
450  }
451
452  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
453                          std::vector<Tensor>* data) {
454    // Choose a random batch size in [1, 4]
455    const int b = 1 + (random::New64() % 4);
456    // z = MatMul(x, y)
457    const int m = Rand();
458    const int k = Rand();
459    const int n = Rand();
460
461    TensorShape x_shape;
462    if (is_batch) {
463      // x.shape = [b, m, k]
464      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
465    } else {
466      // x.shape = [m, k]
467      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
468    }
469    data->emplace_back(DT_FLOAT, x_shape);
470    RandTensor(&data->back());
471
472    TensorShape y_shape;
473    if (is_batch) {
474      // y.shape = [b, k, n]
475      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
476    } else {
477      // y.shape = [k, n]
478      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
479    }
480    data->emplace_back(DT_FLOAT, y_shape);
481    RandTensor(&data->back());
482
483    TensorShape z_shape;
484    if (is_batch) {
485      // z.shape = [b, m, n]
486      z_shape = TensorShape({b, m, n});
487    } else {
488      // z.shape = [m, n]
489      z_shape = TensorShape({m, n});
490    }
491    data->emplace_back(DT_FLOAT, z_shape);
492    RandTensor(&data->back());
493  }
494
495  void RandTensor(Tensor* t) {
496    test::FillFn<float>(
497        t, [this](const int i) { return static_cast<float>(Rand()); });
498  }
499
500  int Rand() { return 1 + (random::New64() % 10); }
501
502  Scope root_;
503};
504
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
536
537}  // namespace
538}  // namespace tensorflow
539