math_grad_test.cc revision fb01ebb8c38b2d274f6fe9a7115b2362828a452e
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/testutil.h"
18#include "tensorflow/cc/gradients/grad_testutil.h"
19#include "tensorflow/cc/ops/standard_ops.h"
20#include "tensorflow/core/framework/tensor_testutil.h"
21#include "tensorflow/core/lib/core/status_test_util.h"
22#include "tensorflow/core/lib/random/random.h"
23
24namespace tensorflow {
25using namespace ops;  // NOLINT(build/namespaces)
26
27namespace {
28
29// TODO(andydavis) Test gradient function against numeric gradients output.
30// TODO(andydavis) As more gradients are added move common test functions
31// to a testutil library.
32
33class CWiseUnaryGradTest : public ::testing::Test {
34 protected:
35  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
36
37  enum UnaryOpType {
38    ABS,
39    NEG,
40    INV,
41    SQUARE,
42    SQRT,
43    RSQRT,
44    EXP,
45    LOG,
46    TANH,
47    SIGMOID,
48    SIGN,
49    SIN,
50    COS,
51    ASIN,
52    ACOS,
53    TAN,
54    ATAN
55  };
56
57  void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn,
58                     std::function<float(float)> dy_fn,
59                     std::function<float(float, float)> dx_fn) {
60    Tensor x(DT_FLOAT, {2, 3, 2});
61    auto x_flat = x.flat<float>();
62    for (int i = 0; i < x_flat.size(); ++i) {
63      x_flat(i) = x_fn(i);
64    }
65
66    Tensor dy(DT_FLOAT, {2, 3, 2});
67    auto dy_flat = dy.flat<float>();
68    for (int i = 0; i < dy_flat.size(); ++i) {
69      dy_flat(i) = dy_fn(x_flat(i));
70    }
71
72    Tensor dx(DT_FLOAT, {2, 3, 2});
73    auto dx_flat = dx.flat<float>();
74    for (int i = 0; i < dx_flat.size(); ++i) {
75      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
76    }
77
78    Output y;
79    switch (op_type) {
80      case ABS:
81        y = Abs(scope_, x);
82        break;
83      case NEG:
84        y = Neg(scope_, x);
85        break;
86      case INV:
87        y = Reciprocal(scope_, x);
88        break;
89      case SQUARE:
90        y = Square(scope_, x);
91        break;
92      case SQRT:
93        y = Sqrt(scope_, x);
94        break;
95      case RSQRT:
96        y = Rsqrt(scope_, x);
97        break;
98      case EXP:
99        y = Exp(scope_, x);
100        break;
101      case LOG:
102        y = Log(scope_, x);
103        break;
104      case TANH:
105        y = Tanh(scope_, x);
106        break;
107      case SIGMOID:
108        y = Sigmoid(scope_, x);
109        break;
110      case SIGN:
111        y = Sign(scope_, x);
112        break;
113      case SIN:
114        y = Sin(scope_, x);
115        break;
116      case COS:
117        y = Cos(scope_, x);
118        break;
119      case ASIN:
120        y = Asin(scope_, x);
121        break;
122      case ACOS:
123        y = Acos(scope_, x);
124        break;
125      case TAN:
126        y = Tan(scope_, x);
127        break;
128      case ATAN:
129        y = Atan(scope_, x);
130        break;
131    }
132
133    std::vector<Output> grad_outputs;
134    TF_ASSERT_OK(test::CallGradFunction(
135        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
136    Tensor output;
137    test::GetTensor(scope_, grad_outputs[0], &output);
138    test::ExpectClose(output, dx);
139  }
140
141  float RV(std::vector<float> v) { return v[random::New64() % v.size()]; }
142
143  Scope scope_;
144};
145
146TEST_F(CWiseUnaryGradTest, Abs) {
147  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
148  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
149  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
150  TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn);
151}
152
153TEST_F(CWiseUnaryGradTest, Neg) {
154  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
155  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
156  auto dx_fn = [this](const float x, const float dy) { return -dy; };
157  TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
158}
159
160TEST_F(CWiseUnaryGradTest, Reciprocal) {
161  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
162  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
163  auto dx_fn = [this](const float x, const float dy) {
164    return -(1 / (x * x)) * dy;
165  };
166  TestCWiseGrad(INV, x_fn, dy_fn, dx_fn);
167}
168
169TEST_F(CWiseUnaryGradTest, Square) {
170  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
171  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
172  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
173  TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn);
174}
175
176TEST_F(CWiseUnaryGradTest, Sqrt) {
177  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
178  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
179  auto dx_fn = [this](const float x, const float dy) {
180    return dy * 0.5 * (1.0 / std::sqrt(x));
181  };
182  TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn);
183}
184
185TEST_F(CWiseUnaryGradTest, Rsqrt) {
186  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
187  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
188  auto dx_fn = [this](const float x, const float dy) {
189    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
190  };
191  TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn);
192}
193
194TEST_F(CWiseUnaryGradTest, Exp) {
195  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
196  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
197  auto dx_fn = [this](const float x, const float dy) {
198    return dy * std::exp(x);
199  };
200  TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn);
201}
202
203TEST_F(CWiseUnaryGradTest, Log) {
204  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
205  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
206  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
207  TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
208}
209
210TEST_F(CWiseUnaryGradTest, Tanh) {
211  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
212  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
213  auto dx_fn = [this](const float x, const float dy) {
214    const float y = std::tanh(x);
215    return dy * (1.0 - y * y);
216  };
217  TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn);
218}
219
220TEST_F(CWiseUnaryGradTest, Sigmoid) {
221  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
222  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
223  auto dx_fn = [this](const float x, const float dy) {
224    const float y = 1.0 / (1.0 + std::exp(-x));
225    return dy * y * (1.0 - y);
226  };
227  TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn);
228}
229
230TEST_F(CWiseUnaryGradTest, Sign) {
231  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
232  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
233  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
234  TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn);
235}
236
237TEST_F(CWiseUnaryGradTest, Sin) {
238  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
239  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
240  auto dx_fn = [this](const float x, const float dy) {
241    return dy * std::cos(x);
242  };
243  TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn);
244}
245
246TEST_F(CWiseUnaryGradTest, Cos) {
247  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
248  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
249  auto dx_fn = [this](const float x, const float dy) {
250    return dy * -1.0 * std::sin(x);
251  };
252  TestCWiseGrad(COS, x_fn, dy_fn, dx_fn);
253}
254
255TEST_F(CWiseUnaryGradTest, Asin) {
256  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
257  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
258  auto dx_fn = [this](const float x, const float dy) {
259    return dy * (1.0 / std::sqrt(1.0 - x * x));
260  };
261  TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn);
262}
263
264TEST_F(CWiseUnaryGradTest, Acos) {
265  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
266  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
267  auto dx_fn = [this](const float x, const float dy) {
268    return dy * (-1.0 / std::sqrt(1.0 - x * x));
269  };
270  TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn);
271}
272
273TEST_F(CWiseUnaryGradTest, Tan) {
274  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
275  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
276  auto dx_fn = [this](const float x, const float dy) {
277    const float cosx = std::cos(x);
278    return dy * (1 / (cosx * cosx));
279  };
280  TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn);
281}
282
283TEST_F(CWiseUnaryGradTest, Atan) {
284  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
285  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
286  auto dx_fn = [this](const float x, const float dy) {
287    return dy * (1 / (1 + x * x));
288  };
289  TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn);
290}
291
292class CWiseUnaryComplexGradTest : public ::testing::Test {
293 protected:
294  CWiseUnaryComplexGradTest()
295      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
296
297  enum UnaryOpType { REAL, IMAG, CONJ };
298
299  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
300                            const Tensor& dy, const Tensor& dx_expected) {
301    Output y;
302    switch (op_type) {
303      case REAL:
304        y = Real(scope_, x);
305        break;
306      case IMAG:
307        y = Imag(scope_, x);
308        break;
309      case CONJ:
310        y = Conj(scope_, x);
311        break;
312    }
313
314    std::vector<Output> grad_outputs;
315    TF_ASSERT_OK(test::CallGradFunction(
316        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
317    Tensor dx;
318    test::GetTensor(scope_, grad_outputs[0], &dx);
319    test::ExpectClose(dx, dx_expected);
320  }
321
322  Scope scope_;
323};
324
325TEST_F(CWiseUnaryComplexGradTest, Real) {
326  Tensor x = test::AsTensor<complex64>(
327      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
328  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
329  Tensor dx_expected = test::AsTensor<complex64>(
330      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
331  TestCWiseGradComplex(REAL, x, dy, dx_expected);
332}
333
334TEST_F(CWiseUnaryComplexGradTest, Imag) {
335  Tensor x = test::AsTensor<complex64>(
336      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
337  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
338  Tensor dx_expected = test::AsTensor<complex64>(
339      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
340  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
341}
342
343TEST_F(CWiseUnaryComplexGradTest, Conj) {
344  Tensor x = test::AsTensor<complex64>(
345      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
346  Tensor dy = test::AsTensor<complex64>(
347      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
348  Tensor dx_expected = test::AsTensor<complex64>(
349      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
350  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
351}
352
353class MathGradTest : public ::testing::Test {
354 protected:
355  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
356
357  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
358    // Generate random test data.
359    std::vector<Tensor> data;
360    RandMatMulGradData(is_batch, t_x, t_y, &data);
361    auto x = Const(root_, data[0]);
362    auto y = Const(root_, data[1]);
363    auto dz = Const(root_, data[2]);
364
365    std::vector<Tensor> grad_outputs;
366    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
367
368    if (!t_x && !t_y) {
369      test::ExpectClose(grad_outputs[0],
370                        ComputeMatMul(is_batch, dz, false, y, true));
371      test::ExpectClose(grad_outputs[1],
372                        ComputeMatMul(is_batch, x, true, dz, false));
373    } else if (t_x && !t_y) {
374      test::ExpectClose(grad_outputs[0],
375                        ComputeMatMul(is_batch, y, false, dz, true));
376      test::ExpectClose(grad_outputs[1],
377                        ComputeMatMul(is_batch, x, false, dz, false));
378    } else if (!t_x && t_y) {
379      test::ExpectClose(grad_outputs[0],
380                        ComputeMatMul(is_batch, dz, false, y, false));
381      test::ExpectClose(grad_outputs[1],
382                        ComputeMatMul(is_batch, dz, true, x, false));
383    } else {
384      test::ExpectClose(grad_outputs[0],
385                        ComputeMatMul(is_batch, y, true, dz, true));
386      test::ExpectClose(grad_outputs[1],
387                        ComputeMatMul(is_batch, dz, true, x, true));
388    }
389  }
390
391  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
392                         const Output& y, const bool t_y, const Output& dz,
393                         std::vector<Tensor>* out) {
394    // Compute forward MatMul: z = MatMul(x, y).
395    Output z;
396    if (is_batch) {
397      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
398    } else {
399      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
400    }
401    TF_ASSERT_OK(root_.status());
402    CHECK_NOTNULL(z.node());
403    std::vector<Output> grad_outputs;
404    // Call MatMulGrad which populates 'grad_outputs'.
405    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
406                                        &grad_outputs));
407    ASSERT_EQ(2, grad_outputs.size());
408    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
409    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
410  }
411
412  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
413                       const Output& y, const bool t_y) {
414    Output z;
415    if (is_batch) {
416      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
417    } else {
418      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
419    }
420    TF_EXPECT_OK(root_.status());
421    Tensor out;
422    test::GetTensor(root_, z, &out);
423    return out;
424  }
425
426  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
427                          std::vector<Tensor>* data) {
428    // Choose a random batch size in [1, 4]
429    const int b = 1 + (random::New64() % 4);
430    // z = MatMul(x, y)
431    const int m = Rand();
432    const int k = Rand();
433    const int n = Rand();
434
435    TensorShape x_shape;
436    if (is_batch) {
437      // x.shape = [b, m, k]
438      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
439    } else {
440      // x.shape = [m, k]
441      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
442    }
443    data->emplace_back(DT_FLOAT, x_shape);
444    RandTensor(&data->back());
445
446    TensorShape y_shape;
447    if (is_batch) {
448      // y.shape = [b, k, n]
449      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
450    } else {
451      // y.shape = [k, n]
452      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
453    }
454    data->emplace_back(DT_FLOAT, y_shape);
455    RandTensor(&data->back());
456
457    TensorShape z_shape;
458    if (is_batch) {
459      // z.shape = [b, m, n]
460      z_shape = TensorShape({b, m, n});
461    } else {
462      // z.shape = [m, n]
463      z_shape = TensorShape({m, n});
464    }
465    data->emplace_back(DT_FLOAT, z_shape);
466    RandTensor(&data->back());
467  }
468
469  void RandTensor(Tensor* t) {
470    test::FillFn<float>(
471        t, [this](const int i) { return static_cast<float>(Rand()); });
472  }
473
474  int Rand() { return 1 + (random::New64() % 10); }
475
476  Scope root_;
477};
478
479TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
480  TestMatMulGrad(false, false, false);
481}
482
483TEST_F(MathGradTest, MatMulGrad_TransposeX) {
484  TestMatMulGrad(false, true, false);
485}
486
487TEST_F(MathGradTest, MatMulGrad_TransposeY) {
488  TestMatMulGrad(false, false, true);
489}
490
491TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
492  TestMatMulGrad(false, true, true);
493}
494
495TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
496  TestMatMulGrad(true, false, false);
497}
498
499TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
500  TestMatMulGrad(true, true, false);
501}
502
503TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
504  TestMatMulGrad(true, false, true);
505}
506
507TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
508  TestMatMulGrad(true, true, true);
509}
510
511}  // namespace
512}  // namespace tensorflow
513