math_grad_test.cc revision 5a1d6d9dac79b46f055462ee52125753524d9f6e
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/gradient_checker.h"
18#include "tensorflow/cc/framework/testutil.h"
19#include "tensorflow/cc/gradients/grad_testutil.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/lib/random/random.h"
24
25namespace tensorflow {
26using namespace ops;  // NOLINT(build/namespaces)
27
28namespace {
29
30// TODO(andydavis) Test gradient function against numeric gradients output.
31// TODO(andydavis) As more gradients are added move common test functions
32// to a testutil library.
33
// Fixture for checking the registered gradient functions of element-wise
// unary math ops against hand-written analytic expectations.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Unary ops covered by TestCWiseGrad; each enumerator selects the
  // corresponding op constructor in the switch inside TestCWiseGrad.
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

  // Core driver: builds y = op(x) for a fixed {2, 3, 2} tensor, invokes the
  // registered gradient function for the op with an injected upstream
  // gradient 'dy', and compares the produced input gradient against an
  // analytically computed expectation.
  //   op_type: which unary op to build (see UnaryOpType).
  //   x_fn:    maps a flat element index to an input value.
  //   dy_fn:   maps an input value to the upstream gradient for that element.
  //   dx_fn:   maps (x, dy) to the expected gradient w.r.t. the input.
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    // Populate the input tensor element-wise from x_fn.
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    // Upstream gradient fed into the gradient function; derived from x so
    // each element's dy is correlated with its input value.
    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    // Expected gradient w.r.t. x, computed element-wise from (x, dy).
    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      // INV exercises the Reciprocal op.
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    // Call the registered gradient function directly (not through AddSymbolicGradients)
    // and compare its first output against the analytic dx.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

  // Returns a uniformly random element of 'v'.
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Returns a uniformly random element of 'v' (complex variant).
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex conjugate of 'val'.
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  // Complex constant 1 + 0i, used by the complex gradient expectations.
  const complex64 one_{1.0, 0};

  Scope scope_;
};
188
189TEST_F(CWiseUnaryGradTest, Abs) {
190  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
191  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
192  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
193  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
194}
195
196TEST_F(CWiseUnaryGradTest, Neg) {
197  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
198  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
199  auto dx_fn = [this](const float x, const float dy) { return -dy; };
200  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
201}
202
203TEST_F(CWiseUnaryGradTest, Reciprocal) {
204  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
205  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
206  auto dx_fn = [this](const float x, const float dy) {
207    return -(1 / (x * x)) * dy;
208  };
209  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
210}
211
212TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
213  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
214  auto dy_fn = [this](const complex64 x) {
215    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
216  };
217  auto dx_fn = [this](const complex64 x, const complex64 dy) {
218    return -conjugate(one_ / (x * x)) * dy;
219  };
220  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
221}
222
223TEST_F(CWiseUnaryGradTest, Square) {
224  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
225  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
226  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
227  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
228}
229
230TEST_F(CWiseUnaryGradTest, Square_Complex) {
231  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
232  auto dy_fn = [this](const complex64& x) {
233    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
234  };
235  auto dx_fn = [this](const complex64& x, const complex64& dy) {
236    return conjugate(complex64(2, 0) * x) * dy;
237  };
238  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
239}
240
241TEST_F(CWiseUnaryGradTest, Sqrt) {
242  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
243  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
244  auto dx_fn = [this](const float x, const float dy) {
245    return dy * 0.5 * (1.0 / std::sqrt(x));
246  };
247  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
248}
249
250TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
251  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
252  auto dy_fn = [this](const complex64& x) {
253    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
254  };
255  auto dx_fn = [this](const complex64& x, const complex64& dy) {
256    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
257  };
258  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
259}
260
261TEST_F(CWiseUnaryGradTest, Rsqrt) {
262  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
263  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
264  auto dx_fn = [this](const float x, const float dy) {
265    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
266  };
267  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
268}
269
270TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
271  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
272  auto dy_fn = [this](const complex64& x) {
273    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
274  };
275  auto dx_fn = [this](const complex64& x, const complex64& dy) {
276    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
277  };
278  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
279}
280
281TEST_F(CWiseUnaryGradTest, Exp) {
282  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
283  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
284  auto dx_fn = [this](const float x, const float dy) {
285    return dy * std::exp(x);
286  };
287  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
288}
289
290TEST_F(CWiseUnaryGradTest, Exp_Complex) {
291  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
292  auto dy_fn = [this](const complex64& x) {
293    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
294  };
295  auto dx_fn = [this](const complex64& x, const complex64& dy) {
296    return dy * conjugate(std::exp(x));
297  };
298  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
299}
300
301TEST_F(CWiseUnaryGradTest, Expm1) {
302  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
303  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
304  auto dx_fn = [this](const float x, const float dy) {
305    return dy * std::exp(x);
306  };
307  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
308}
309
310TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
311  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
312  auto dy_fn = [this](const complex64& x) {
313    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
314  };
315  auto dx_fn = [this](const complex64& x, const complex64& dy) {
316    return dy * conjugate(std::exp(x));
317  };
318  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
319}
320
321TEST_F(CWiseUnaryGradTest, Log) {
322  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
323  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
324  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
325  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
326}
327
328TEST_F(CWiseUnaryGradTest, Log_Complex) {
329  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
330  auto dy_fn = [this](const complex64& x) {
331    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
332  };
333  auto dx_fn = [this](const complex64& x, const complex64& dy) {
334    return dy * conjugate(one_ / x);
335  };
336  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
337}
338
339TEST_F(CWiseUnaryGradTest, Log1p) {
340  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
341  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
342  auto dx_fn = [this](const float x, const float dy) {
343    return dy * (1.0 / (1.0 + x));
344  };
345  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
346}
347
348TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
349  auto x_fn = [this](const int i) {
350    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
351  };
352  auto dy_fn = [this](const complex64& x) {
353    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
354  };
355  auto dx_fn = [this](const complex64& x, const complex64& dy) {
356    return dy / (one_ + conjugate(x));
357  };
358  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
359}
360
361TEST_F(CWiseUnaryGradTest, Sinh) {
362  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
363  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
364  auto dx_fn = [this](const float x, const float dy) {
365    return dy * std::cosh(x);
366  };
367  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
368}
369
370TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
371  auto x_fn = [this](const int i) {
372    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
373  };
374  auto dy_fn = [this](const complex64& x) {
375    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
376  };
377  auto dx_fn = [this](const complex64& x, const complex64& dy) {
378    return dy * conjugate(std::cosh(x));
379  };
380  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
381}
382
383TEST_F(CWiseUnaryGradTest, Cosh) {
384  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
385  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
386  auto dx_fn = [this](const float x, const float dy) {
387    return dy * std::sinh(x);
388  };
389  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
390}
391
392TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
393  auto x_fn = [this](const int i) {
394    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
395  };
396  auto dy_fn = [this](const complex64& x) {
397    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
398  };
399  auto dx_fn = [this](const complex64& x, const complex64& dy) {
400    return dy * conjugate(std::sinh(x));
401  };
402  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
403}
404
405TEST_F(CWiseUnaryGradTest, Tanh) {
406  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
407  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
408  auto dx_fn = [this](const float x, const float dy) {
409    const float y = std::tanh(x);
410    return dy * (1.0 - y * y);
411  };
412  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
413}
414
415TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
416  auto x_fn = [this](const int i) {
417    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
418  };
419  auto dy_fn = [this](const complex64& x) {
420    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
421  };
422  auto dx_fn = [this](const complex64& x, const complex64& dy) {
423    const complex64 y = std::tanh(x);
424    return dy * conjugate((one_ - y * y));
425  };
426  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
427}
428
429TEST_F(CWiseUnaryGradTest, Asinh) {
430  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
431  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
432  auto dx_fn = [this](const float x, const float dy) {
433    auto y = std::asinh(x);
434    return dy / std::cosh(y);
435  };
436  TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn);
437}
438
439TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
440  auto x_fn = [this](const int i) {
441    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
442  };
443  auto dy_fn = [this](const complex64& x) {
444    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
445  };
446  auto dx_fn = [this](const complex64& x, const complex64& dy) {
447    auto y = std::asinh(x);
448    return dy / conjugate(std::cosh(y));
449  };
450  TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn);
451}
452
453TEST_F(CWiseUnaryGradTest, Acosh) {
454  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); };
455  auto dy_fn = [this](const float x) {
456    return x + RV({8, 9, 10, 11, 12, 13, 14});
457  };
458  auto dx_fn = [this](const float x, const float dy) {
459    auto y = std::acosh(x);
460    return dy / std::sinh(y);
461  };
462  TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn);
463}
464
465TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
466  auto x_fn = [this](const int i) {
467    return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}});
468  };
469  auto dy_fn = [this](const complex64& x) {
470    return x + CRV({{2, 2}, {3, 3}, {1, 4}});
471  };
472  auto dx_fn = [this](const complex64& x, const complex64& dy) {
473    auto y = std::acosh(x);
474    return dy / conjugate(std::sinh(y));
475  };
476  TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn);
477}
478
479TEST_F(CWiseUnaryGradTest, Atanh) {
480  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
481  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
482  auto dx_fn = [this](const float x, const float dy) {
483    return dy * (1. / (1. - x * x));
484  };
485  TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn);
486}
487
488TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
489  auto x_fn = [this](const int i) {
490    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
491  };
492  auto dy_fn = [this](const complex64& x) {
493    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
494  };
495  auto dx_fn = [this](const complex64& x, const complex64& dy) {
496    return dy / conjugate(one_ - x * x);
497  };
498  TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn);
499}
500
501TEST_F(CWiseUnaryGradTest, Sigmoid) {
502  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
503  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
504  auto dx_fn = [this](const float x, const float dy) {
505    const float y = 1.0 / (1.0 + std::exp(-x));
506    return dy * y * (1.0 - y);
507  };
508  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
509}
510
511TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
512  auto x_fn = [this](const int i) {
513    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
514  };
515  auto dy_fn = [this](const complex64& x) {
516    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
517  };
518  auto dx_fn = [this](const complex64& x, const complex64& dy) {
519    const complex64 y = one_ / (one_ + std::exp(-x));
520    return dy * conjugate(y * (one_ - y));
521  };
522  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
523}
524
525TEST_F(CWiseUnaryGradTest, Sign) {
526  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
527  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
528  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
529  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
530}
531
532TEST_F(CWiseUnaryGradTest, Sin) {
533  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
534  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
535  auto dx_fn = [this](const float x, const float dy) {
536    return dy * std::cos(x);
537  };
538  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
539}
540
541TEST_F(CWiseUnaryGradTest, Sin_Complex) {
542  auto x_fn = [this](const int i) {
543    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
544  };
545  auto dy_fn = [this](const complex64& x) {
546    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
547  };
548  auto dx_fn = [this](const complex64& x, const complex64& dy) {
549    return dy * conjugate(std::cos(x));
550  };
551  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
552}
553
554TEST_F(CWiseUnaryGradTest, Cos) {
555  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
556  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
557  auto dx_fn = [this](const float x, const float dy) {
558    return dy * -1.0 * std::sin(x);
559  };
560  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
561}
562
563TEST_F(CWiseUnaryGradTest, Cos_Complex) {
564  auto x_fn = [this](const int i) {
565    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
566  };
567  auto dy_fn = [this](const complex64& x) {
568    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
569  };
570  auto dx_fn = [this](const complex64& x, const complex64& dy) {
571    return dy * conjugate(-std::sin(x));
572  };
573  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
574}
575
576TEST_F(CWiseUnaryGradTest, Asin) {
577  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
578  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
579  auto dx_fn = [this](const float x, const float dy) {
580    return dy * (1.0 / std::sqrt(1.0 - x * x));
581  };
582  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
583}
584
585TEST_F(CWiseUnaryGradTest, Asin_Complex) {
586  auto x_fn = [this](const int i) {
587    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
588  };
589  auto dy_fn = [this](const complex64& x) {
590    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
591  };
592  auto dx_fn = [this](const complex64& x, const complex64& dy) {
593    return dy / conjugate(std::sqrt(one_ - x * x));
594  };
595  // TODO(kbsriram)
596  // Enable test when the asin kernel supports complex numbers
597  if (false) {
598    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
599  }
600}
601
602TEST_F(CWiseUnaryGradTest, Acos) {
603  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
604  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
605  auto dx_fn = [this](const float x, const float dy) {
606    return dy * (-1.0 / std::sqrt(1.0 - x * x));
607  };
608  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
609}
610
611TEST_F(CWiseUnaryGradTest, Acos_Complex) {
612  auto x_fn = [this](const int i) {
613    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
614  };
615  auto dy_fn = [this](const complex64& x) {
616    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
617  };
618  auto dx_fn = [this](const complex64& x, const complex64& dy) {
619    return dy / -conjugate(std::sqrt(one_ - x * x));
620  };
621  // TODO(kbsriram)
622  // Add test when the acos kernel supports complex numbers
623  if (false) {
624    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
625  }
626}
627
628TEST_F(CWiseUnaryGradTest, Tan) {
629  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
630  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
631  auto dx_fn = [this](const float x, const float dy) {
632    const float cosx = std::cos(x);
633    return dy * (1 / (cosx * cosx));
634  };
635  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
636}
637
638TEST_F(CWiseUnaryGradTest, Tan_Complex) {
639  auto x_fn = [this](const int i) {
640    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
641  };
642  auto dy_fn = [this](const complex64& x) {
643    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
644  };
645  auto dx_fn = [this](const complex64& x, const complex64& dy) {
646    const complex64 cosx = std::cos(x);
647    return dy / conjugate(cosx * cosx);
648  };
649  // TODO(kbsriram)
650  // Enable when tan kernel supports complex inputs
651  if (false) {
652    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
653  }
654}
655
656TEST_F(CWiseUnaryGradTest, Atan) {
657  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
658  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
659  auto dx_fn = [this](const float x, const float dy) {
660    return dy * (1 / (1 + x * x));
661  };
662  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
663}
664
665TEST_F(CWiseUnaryGradTest, Atan_Complex) {
666  auto x_fn = [this](const int i) {
667    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
668  };
669  auto dy_fn = [this](const complex64& x) {
670    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
671  };
672  auto dx_fn = [this](const complex64& x, const complex64& dy) {
673    return dy / (one_ + x * x);
674  };
675  // TODO(kbsriram)
676  // Add test when the atan kernel supports complex numbers
677  if (false) {
678    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
679  }
680}
681
// Fixture for gradient checks of unary ops that map between complex and
// real domains (Real/Imag/Angle) or conjugate (Conj).
class CWiseUnaryComplexGradTest : public ::testing::Test {
 protected:
  CWiseUnaryComplexGradTest()
      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Ops covered by TestCWiseGradComplex.
  enum UnaryOpType { REAL, IMAG, ANGLE, CONJ };

  // Builds y = op(x), invokes the registered gradient function with the
  // upstream gradient 'dy', and compares the computed input gradient
  // against the caller-supplied 'dx_expected'.
  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
                            const Tensor& dy, const Tensor& dx_expected) {
    Output y;
    switch (op_type) {
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor dx;
    test::GetTensor(scope_, grad_outputs[0], &dx);
    test::ExpectClose(dx, dx_expected);
  }

  Scope scope_;
};
717
TEST_F(CWiseUnaryComplexGradTest, Real) {
  // Gradient of Real places the real-valued dy in the real component and
  // zero in the imaginary component.
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
  TestCWiseGradComplex(REAL, x, dy, dx_expected);
}
726
TEST_F(CWiseUnaryComplexGradTest, Imag) {
  // Gradient of Imag places the real-valued dy in the imaginary component
  // and zero in the real component.
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}
735
TEST_F(CWiseUnaryComplexGradTest, Angle) {
  // Expected values precomputed from the Angle gradient at each input;
  // e.g. for x = 1 - 1i, dy = 11: |x|^2 = 2 yields (5.5, 5.5).
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{5.5, 5.5}, {3, 3},
       {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
       {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
  TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
}
746
TEST_F(CWiseUnaryComplexGradTest, Conj) {
  // Gradient of Conj is conjugation of dy: each expected element is the
  // conjugate of the corresponding dy element.
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
}
756
// Fixture verifying MatMul / BatchMatMul gradients for every transpose
// configuration against the closed-form matrix-calculus expressions.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // For z = MatMul(x, y) under the requested (is_batch, t_x, t_y)
  // configuration, checks the two gradient outputs against the analytic
  // identities, e.g. for the untransposed case: dx = dz * y^T, dy = x^T * dz.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    // Each branch encodes the expected gradients for one transpose
    // configuration; grad_outputs[0] is dx, grad_outputs[1] is dy.
    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Runs the registered MatMul/BatchMatMul gradient function for
  // z = MatMul(x, y) with upstream gradient dz, and returns the evaluated
  // gradient tensors (dx, dy) in 'out'.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul when is_batch) with the given
  // transpose flags and returns the resulting tensor.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Appends three random tensors to 'data' — x, y, dz — with shapes
  // consistent with the requested batch/transpose configuration.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with random values drawn via Rand().
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random value in [1, 10]; used both for dimension sizes and tensor fill.
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};
882
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}
886
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
890
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}
894
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}
898
// Gradient check for batched z = BatchMatMul(x, y), no transposes.
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}
902
// Gradient check for batched z = BatchMatMul(x, y) with x transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}
906
// Gradient check for batched z = BatchMatMul(x, y) with y transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}
910
// Gradient check for batched z = BatchMatMul(x, y), both inputs transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}
914
915class NaryGradTest : public ::testing::Test {
916 protected:
917  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
918
919  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
920               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
921    TF_ASSERT_OK(scope_.status());
922    float max_error;
923    TF_ASSERT_OK(
924        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
925    EXPECT_LT(max_error, 1e-3);
926  }
927
928  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
929               const TensorShape& y_shape) {
930    TF_ASSERT_OK(scope_.status());
931    float max_error;
932    TF_ASSERT_OK(
933        ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
934    EXPECT_LT(max_error, 1e-3);
935  }
936
937  Scope scope_;
938};
939
940TEST_F(NaryGradTest, Sum) {
941  TensorShape x_shape({2, 3, 5, 7});
942  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
943  auto y = Sum(scope_, x, {1, -1});
944  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
945  TensorShape y_shape({2, 5});
946  RunTest({x}, {x_shape}, {y}, {y_shape});
947}
948
949TEST_F(NaryGradTest, Mean) {
950  TensorShape x_shape({2, 3, 5, 7});
951  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
952  auto y = Mean(scope_, x, {1, -1});
953  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
954  TensorShape y_shape({2, 5});
955  RunTest({x}, {x_shape}, {y}, {y_shape});
956}
957
958TEST_F(NaryGradTest, AddN) {
959  TensorShape shape({3, 2, 5});
960  std::vector<Output> xs;
961  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
962  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
963  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
964  auto y = AddN(scope_, xs);
965  RunTest(xs, {shape, shape, shape}, {y}, {shape});
966}
967
968TEST_F(NaryGradTest, Add) {
969  TensorShape x1_shape({3, 2, 5});
970  TensorShape x2_shape({2, 5});
971  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
972  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
973  auto y = Add(scope_, x1, x2);
974  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
975}
976
977TEST_F(NaryGradTest, Sub) {
978  TensorShape x1_shape({3, 2, 5});
979  TensorShape x2_shape({2, 5});
980  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
981  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
982  auto y = Sub(scope_, x1, x2);
983  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
984}
985
986TEST_F(NaryGradTest, Mul) {
987  TensorShape x1_shape({3, 2, 5});
988  TensorShape x2_shape({2, 5});
989  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
990  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
991  auto y = Mul(scope_, x1, x2);
992  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
993}
994
995TEST_F(NaryGradTest, Div) {
996  TensorShape x_shape({3, 2, 5});
997  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
998  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
999  // division errors in the numeric estimator used by the gradient checker.
1000  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
1001  RunTest({x}, {x_shape}, {y}, {x_shape});
1002}
1003
1004TEST_F(NaryGradTest, RealDiv) {
1005  TensorShape x_shape({3, 2, 5});
1006  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1007  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
1008  // division errors in the numeric estimator used by the gradient checker.
1009  auto y =
1010      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
1011  RunTest({x}, {x_shape}, {y}, {x_shape});
1012}
1013
1014TEST_F(NaryGradTest, SquaredDifference) {
1015  TensorShape x1_shape({3, 2, 5});
1016  TensorShape x2_shape({2, 5});
1017  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
1018  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
1019  auto y = SquaredDifference(scope_, x1, x2);
1020  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
1021}
1022
1023TEST_F(NaryGradTest, Maximum) {
1024  TensorShape shape({3, 2});
1025  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1026  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
1027  // Select values away from 1.0f to avoid instability when computing
1028  // finite differences.
1029  Tensor x_init_value =
1030      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
1031  RunTest(x, x_init_value, y, shape);
1032}
1033
1034TEST_F(NaryGradTest, Minimum) {
1035  TensorShape shape({3, 2});
1036  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1037  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
1038  // Select values away from 1.0f to avoid instability when computing
1039  // finite differences.
1040  Tensor x_init_value =
1041      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
1042  RunTest(x, x_init_value, y, shape);
1043}
1044
1045}  // namespace
1046}  // namespace tensorflow
1047