// math_grad_test.cc revision 16001fc526831c7a7f1a3814f517b01008df4c4c
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/gradient_checker.h"
18#include "tensorflow/cc/framework/testutil.h"
19#include "tensorflow/cc/gradients/grad_testutil.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/lib/random/random.h"
24
25namespace tensorflow {
26using namespace ops;  // NOLINT(build/namespaces)
27
28namespace {
29
30// TODO(andydavis) Test gradient function against numeric gradients output.
31// TODO(andydavis) As more gradients are added move common test functions
32// to a testutil library.
33
34class CWiseUnaryGradTest : public ::testing::Test {
35 protected:
36  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
37
38  enum UnaryOpType {
39    ABS,
40    NEG,
41    INV,
42    SQUARE,
43    SQRT,
44    RSQRT,
45    EXP,
46    EXPM1,
47    LOG,
48    LOG1P,
49    SINH,
50    COSH,
51    TANH,
52    ASINH,
53    ACOSH,
54    ATANH,
55    SIGMOID,
56    SIGN,
57    SIN,
58    COS,
59    ASIN,
60    ACOS,
61    TAN,
62    ATAN
63  };
64
65  template <typename T>
66  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
67                     const std::function<T(const T&)>& dy_fn,
68                     const std::function<T(const T&, const T&)>& dx_fn) {
69    DataType dtype = DataTypeToEnum<T>::v();
70    Tensor x(dtype, {2, 3, 2});
71    auto x_flat = x.flat<T>();
72    for (int i = 0; i < x_flat.size(); ++i) {
73      x_flat(i) = x_fn(i);
74    }
75
76    Tensor dy(dtype, {2, 3, 2});
77    auto dy_flat = dy.flat<T>();
78    for (int i = 0; i < dy_flat.size(); ++i) {
79      dy_flat(i) = dy_fn(x_flat(i));
80    }
81
82    Tensor dx(dtype, {2, 3, 2});
83    auto dx_flat = dx.flat<T>();
84    for (int i = 0; i < dx_flat.size(); ++i) {
85      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
86    }
87
88    Output y;
89    switch (op_type) {
90      case ABS:
91        y = Abs(scope_, x);
92        break;
93      case NEG:
94        y = Neg(scope_, x);
95        break;
96      case INV:
97        y = Reciprocal(scope_, x);
98        break;
99      case SQUARE:
100        y = Square(scope_, x);
101        break;
102      case SQRT:
103        y = Sqrt(scope_, x);
104        break;
105      case RSQRT:
106        y = Rsqrt(scope_, x);
107        break;
108      case EXP:
109        y = Exp(scope_, x);
110        break;
111      case EXPM1:
112        y = Expm1(scope_, x);
113        break;
114      case LOG:
115        y = Log(scope_, x);
116        break;
117      case LOG1P:
118        y = Log1p(scope_, x);
119        break;
120      case SINH:
121        y = Sinh(scope_, x);
122        break;
123      case COSH:
124        y = Cosh(scope_, x);
125        break;
126      case TANH:
127        y = Tanh(scope_, x);
128        break;
129      case ASINH:
130        y = Asinh(scope_, x);
131        break;
132      case ACOSH:
133        y = Acosh(scope_, x);
134        break;
135      case ATANH:
136        y = Atanh(scope_, x);
137        break;
138      case SIGMOID:
139        y = Sigmoid(scope_, x);
140        break;
141      case SIGN:
142        y = Sign(scope_, x);
143        break;
144      case SIN:
145        y = Sin(scope_, x);
146        break;
147      case COS:
148        y = Cos(scope_, x);
149        break;
150      case ASIN:
151        y = Asin(scope_, x);
152        break;
153      case ACOS:
154        y = Acos(scope_, x);
155        break;
156      case TAN:
157        y = Tan(scope_, x);
158        break;
159      case ATAN:
160        y = Atan(scope_, x);
161        break;
162    }
163
164    std::vector<Output> grad_outputs;
165    TF_ASSERT_OK(test::CallGradFunction(
166        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
167    Tensor output;
168    test::GetTensor(scope_, grad_outputs[0], &output);
169    test::ExpectClose(output, dx);
170  }
171
172  float RV(const std::vector<float>& v) {
173    return v[random::New64() % v.size()];
174  }
175
176  complex64 CRV(const std::vector<complex64>& v) {
177    return v[random::New64() % v.size()];
178  }
179
180  complex64 conjugate(const complex64& val) {
181    return complex64(val.real(), -val.imag());
182  }
183
184  const complex64 one_{1.0, 0};
185
186  Scope scope_;
187};
188
189TEST_F(CWiseUnaryGradTest, Abs) {
190  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
191  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
192  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
193  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
194}
195
196TEST_F(CWiseUnaryGradTest, Neg) {
197  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
198  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
199  auto dx_fn = [this](const float x, const float dy) { return -dy; };
200  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
201}
202
203TEST_F(CWiseUnaryGradTest, Reciprocal) {
204  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
205  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
206  auto dx_fn = [this](const float x, const float dy) {
207    return -(1 / (x * x)) * dy;
208  };
209  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
210}
211
212TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
213  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
214  auto dy_fn = [this](const complex64 x) {
215    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
216  };
217  auto dx_fn = [this](const complex64 x, const complex64 dy) {
218    return -conjugate(one_ / (x * x)) * dy;
219  };
220  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
221}
222
223TEST_F(CWiseUnaryGradTest, Square) {
224  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
225  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
226  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
227  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
228}
229
230TEST_F(CWiseUnaryGradTest, Square_Complex) {
231  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
232  auto dy_fn = [this](const complex64& x) {
233    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
234  };
235  auto dx_fn = [this](const complex64& x, const complex64& dy) {
236    return conjugate(complex64(2, 0) * x) * dy;
237  };
238  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
239}
240
241TEST_F(CWiseUnaryGradTest, Sqrt) {
242  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
243  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
244  auto dx_fn = [this](const float x, const float dy) {
245    return dy * 0.5 * (1.0 / std::sqrt(x));
246  };
247  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
248}
249
250TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
251  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
252  auto dy_fn = [this](const complex64& x) {
253    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
254  };
255  auto dx_fn = [this](const complex64& x, const complex64& dy) {
256    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
257  };
258  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
259}
260
261TEST_F(CWiseUnaryGradTest, Rsqrt) {
262  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
263  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
264  auto dx_fn = [this](const float x, const float dy) {
265    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
266  };
267  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
268}
269
270TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
271  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
272  auto dy_fn = [this](const complex64& x) {
273    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
274  };
275  auto dx_fn = [this](const complex64& x, const complex64& dy) {
276    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
277  };
278  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
279}
280
281TEST_F(CWiseUnaryGradTest, Exp) {
282  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
283  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
284  auto dx_fn = [this](const float x, const float dy) {
285    return dy * std::exp(x);
286  };
287  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
288}
289
290TEST_F(CWiseUnaryGradTest, Exp_Complex) {
291  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
292  auto dy_fn = [this](const complex64& x) {
293    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
294  };
295  auto dx_fn = [this](const complex64& x, const complex64& dy) {
296    return dy * conjugate(std::exp(x));
297  };
298  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
299}
300
301TEST_F(CWiseUnaryGradTest, Expm1) {
302  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
303  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
304  auto dx_fn = [this](const float x, const float dy) {
305    return dy * std::exp(x);
306  };
307  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
308}
309
310TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
311  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
312  auto dy_fn = [this](const complex64& x) {
313    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
314  };
315  auto dx_fn = [this](const complex64& x, const complex64& dy) {
316    return dy * conjugate(std::exp(x));
317  };
318  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
319}
320
321TEST_F(CWiseUnaryGradTest, Log) {
322  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
323  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
324  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
325  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
326}
327
328TEST_F(CWiseUnaryGradTest, Log_Complex) {
329  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
330  auto dy_fn = [this](const complex64& x) {
331    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
332  };
333  auto dx_fn = [this](const complex64& x, const complex64& dy) {
334    return dy * conjugate(one_ / x);
335  };
336  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
337}
338
339TEST_F(CWiseUnaryGradTest, Log1p) {
340  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
341  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
342  auto dx_fn = [this](const float x, const float dy) {
343    return dy * (1.0 / (1.0 + x));
344  };
345  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
346}
347
348TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
349  auto x_fn = [this](const int i) {
350    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
351  };
352  auto dy_fn = [this](const complex64& x) {
353    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
354  };
355  auto dx_fn = [this](const complex64& x, const complex64& dy) {
356    return dy / (one_ + conjugate(x));
357  };
358  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
359}
360
361TEST_F(CWiseUnaryGradTest, Sinh) {
362  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
363  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
364  auto dx_fn = [this](const float x, const float dy) {
365    return dy * std::cosh(x);
366  };
367  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
368}
369
370TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
371  auto x_fn = [this](const int i) {
372    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
373  };
374  auto dy_fn = [this](const complex64& x) {
375    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
376  };
377  auto dx_fn = [this](const complex64& x, const complex64& dy) {
378    return dy * conjugate(std::cosh(x));
379  };
380  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
381}
382
383TEST_F(CWiseUnaryGradTest, Cosh) {
384  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
385  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
386  auto dx_fn = [this](const float x, const float dy) {
387    return dy * std::sinh(x);
388  };
389  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
390}
391
392TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
393  auto x_fn = [this](const int i) {
394    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
395  };
396  auto dy_fn = [this](const complex64& x) {
397    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
398  };
399  auto dx_fn = [this](const complex64& x, const complex64& dy) {
400    return dy * conjugate(std::sinh(x));
401  };
402  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
403}
404
405TEST_F(CWiseUnaryGradTest, Tanh) {
406  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
407  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
408  auto dx_fn = [this](const float x, const float dy) {
409    const float y = std::tanh(x);
410    return dy * (1.0 - y * y);
411  };
412  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
413}
414
415TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
416  auto x_fn = [this](const int i) {
417    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
418  };
419  auto dy_fn = [this](const complex64& x) {
420    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
421  };
422  auto dx_fn = [this](const complex64& x, const complex64& dy) {
423    const complex64 y = std::tanh(x);
424    return dy * conjugate((one_ - y * y));
425  };
426  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
427}
428
429TEST_F(CWiseUnaryGradTest, Asinh) {
430  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
431  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
432  auto dx_fn = [this](const float x, const float dy) {
433    auto y = std::asinh(x);
434    return dy / std::cosh(y);
435  };
436  TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn);
437}
438
439TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
440  auto x_fn = [this](const int i) {
441    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
442  };
443  auto dy_fn = [this](const complex64& x) {
444    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
445  };
446  auto dx_fn = [this](const complex64& x, const complex64& dy) {
447    auto y = std::asinh(x);
448    return dy / conjugate(std::cosh(y));
449  };
450  TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn);
451}
452
453TEST_F(CWiseUnaryGradTest, Acosh) {
454  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); };
455  auto dy_fn = [this](const float x) {
456    return x + RV({8, 9, 10, 11, 12, 13, 14});
457  };
458  auto dx_fn = [this](const float x, const float dy) {
459    auto y = std::acosh(x);
460    return dy / std::sinh(y);
461  };
462  TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn);
463}
464
465TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
466  auto x_fn = [this](const int i) {
467    return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}});
468  };
469  auto dy_fn = [this](const complex64& x) {
470    return x + CRV({{2, 2}, {3, 3}, {1, 4}});
471  };
472  auto dx_fn = [this](const complex64& x, const complex64& dy) {
473    auto y = std::acosh(x);
474    return dy / conjugate(std::sinh(y));
475  };
476  TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn);
477}
478
479TEST_F(CWiseUnaryGradTest, Atanh) {
480  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
481  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
482  auto dx_fn = [this](const float x, const float dy) {
483    return dy * (1. / (1. - x * x));
484  };
485  TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn);
486}
487
488TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
489  auto x_fn = [this](const int i) {
490    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
491  };
492  auto dy_fn = [this](const complex64& x) {
493    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
494  };
495  auto dx_fn = [this](const complex64& x, const complex64& dy) {
496    return dy / conjugate(one_ - x * x);
497  };
498  TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn);
499}
500
501TEST_F(CWiseUnaryGradTest, Sigmoid) {
502  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
503  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
504  auto dx_fn = [this](const float x, const float dy) {
505    const float y = 1.0 / (1.0 + std::exp(-x));
506    return dy * y * (1.0 - y);
507  };
508  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
509}
510
511TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
512  auto x_fn = [this](const int i) {
513    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
514  };
515  auto dy_fn = [this](const complex64& x) {
516    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
517  };
518  auto dx_fn = [this](const complex64& x, const complex64& dy) {
519    const complex64 y = one_ / (one_ + std::exp(-x));
520    return dy * conjugate(y * (one_ - y));
521  };
522  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
523}
524
525TEST_F(CWiseUnaryGradTest, Sign) {
526  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
527  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
528  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
529  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
530}
531
532TEST_F(CWiseUnaryGradTest, Sin) {
533  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
534  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
535  auto dx_fn = [this](const float x, const float dy) {
536    return dy * std::cos(x);
537  };
538  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
539}
540
541TEST_F(CWiseUnaryGradTest, Sin_Complex) {
542  auto x_fn = [this](const int i) {
543    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
544  };
545  auto dy_fn = [this](const complex64& x) {
546    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
547  };
548  auto dx_fn = [this](const complex64& x, const complex64& dy) {
549    return dy * conjugate(std::cos(x));
550  };
551  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
552}
553
554TEST_F(CWiseUnaryGradTest, Cos) {
555  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
556  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
557  auto dx_fn = [this](const float x, const float dy) {
558    return dy * -1.0 * std::sin(x);
559  };
560  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
561}
562
563TEST_F(CWiseUnaryGradTest, Cos_Complex) {
564  auto x_fn = [this](const int i) {
565    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
566  };
567  auto dy_fn = [this](const complex64& x) {
568    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
569  };
570  auto dx_fn = [this](const complex64& x, const complex64& dy) {
571    return dy * conjugate(-std::sin(x));
572  };
573  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
574}
575
576TEST_F(CWiseUnaryGradTest, Asin) {
577  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
578  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
579  auto dx_fn = [this](const float x, const float dy) {
580    return dy * (1.0 / std::sqrt(1.0 - x * x));
581  };
582  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
583}
584
585TEST_F(CWiseUnaryGradTest, Asin_Complex) {
586  auto x_fn = [this](const int i) {
587    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
588  };
589  auto dy_fn = [this](const complex64& x) {
590    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
591  };
592  auto dx_fn = [this](const complex64& x, const complex64& dy) {
593    return dy / conjugate(std::sqrt(one_ - x * x));
594  };
595  // TODO(kbsriram)
596  // Enable test when the asin kernel supports complex numbers
597  if (false) {
598    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
599  }
600}
601
602TEST_F(CWiseUnaryGradTest, Acos) {
603  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
604  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
605  auto dx_fn = [this](const float x, const float dy) {
606    return dy * (-1.0 / std::sqrt(1.0 - x * x));
607  };
608  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
609}
610
611TEST_F(CWiseUnaryGradTest, Acos_Complex) {
612  auto x_fn = [this](const int i) {
613    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
614  };
615  auto dy_fn = [this](const complex64& x) {
616    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
617  };
618  auto dx_fn = [this](const complex64& x, const complex64& dy) {
619    return dy / -conjugate(std::sqrt(one_ - x * x));
620  };
621  // TODO(kbsriram)
622  // Add test when the acos kernel supports complex numbers
623  if (false) {
624    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
625  }
626}
627
628TEST_F(CWiseUnaryGradTest, Tan) {
629  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
630  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
631  auto dx_fn = [this](const float x, const float dy) {
632    const float cosx = std::cos(x);
633    return dy * (1 / (cosx * cosx));
634  };
635  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
636}
637
638TEST_F(CWiseUnaryGradTest, Tan_Complex) {
639  auto x_fn = [this](const int i) {
640    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
641  };
642  auto dy_fn = [this](const complex64& x) {
643    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
644  };
645  auto dx_fn = [this](const complex64& x, const complex64& dy) {
646    const complex64 cosx = std::cos(x);
647    return dy / conjugate(cosx * cosx);
648  };
649  // TODO(kbsriram)
650  // Enable when tan kernel supports complex inputs
651  if (false) {
652    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
653  }
654}
655
656TEST_F(CWiseUnaryGradTest, Atan) {
657  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
658  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
659  auto dx_fn = [this](const float x, const float dy) {
660    return dy * (1 / (1 + x * x));
661  };
662  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
663}
664
665TEST_F(CWiseUnaryGradTest, Atan_Complex) {
666  auto x_fn = [this](const int i) {
667    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
668  };
669  auto dy_fn = [this](const complex64& x) {
670    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
671  };
672  auto dx_fn = [this](const complex64& x, const complex64& dy) {
673    return dy / (one_ + x * x);
674  };
675  // TODO(kbsriram)
676  // Add test when the atan kernel supports complex numbers
677  if (false) {
678    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
679  }
680}
681
682class CWiseUnaryComplexGradTest : public ::testing::Test {
683 protected:
684  CWiseUnaryComplexGradTest()
685      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
686
687  enum UnaryOpType { REAL, IMAG, CONJ };
688
689  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
690                            const Tensor& dy, const Tensor& dx_expected) {
691    Output y;
692    switch (op_type) {
693      case REAL:
694        y = Real(scope_, x);
695        break;
696      case IMAG:
697        y = Imag(scope_, x);
698        break;
699      case CONJ:
700        y = Conj(scope_, x);
701        break;
702    }
703
704    std::vector<Output> grad_outputs;
705    TF_ASSERT_OK(test::CallGradFunction(
706        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
707    Tensor dx;
708    test::GetTensor(scope_, grad_outputs[0], &dx);
709    test::ExpectClose(dx, dx_expected);
710  }
711
712  Scope scope_;
713};
714
715TEST_F(CWiseUnaryComplexGradTest, Real) {
716  Tensor x = test::AsTensor<complex64>(
717      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
718  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
719  Tensor dx_expected = test::AsTensor<complex64>(
720      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
721  TestCWiseGradComplex(REAL, x, dy, dx_expected);
722}
723
724TEST_F(CWiseUnaryComplexGradTest, Imag) {
725  Tensor x = test::AsTensor<complex64>(
726      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
727  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
728  Tensor dx_expected = test::AsTensor<complex64>(
729      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
730  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
731}
732
733TEST_F(CWiseUnaryComplexGradTest, Conj) {
734  Tensor x = test::AsTensor<complex64>(
735      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
736  Tensor dy = test::AsTensor<complex64>(
737      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
738  Tensor dx_expected = test::AsTensor<complex64>(
739      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
740  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
741}
742
// Fixture for the MatMul / BatchMatMul gradient functions.  For each of the
// four (t_x, t_y) transpose combinations it checks the gradients produced by
// the registered gradient function against explicitly constructed matmuls of
// x, y, and dz.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs one transpose configuration end to end: generates random shapes and
  // data, computes the gradients, and compares them with the closed-form
  // expressions for d(MatMul)/dx and d(MatMul)/dy.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    // Each branch spells out the expected gradients for one transpose
    // combination, e.g. for z = x*y: dx = dz*y^T and dy = x^T*dz.
    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Builds z = MatMul(x, y) (or BatchMatMul), runs the registered gradient
  // function with upstream gradient 'dz', and returns the evaluated dx and dy
  // tensors in 'out'.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul) with the given transpose flags
  // and returns the resulting tensor; used to build expected gradients.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Fills 'data' with three random tensors {x, y, dz} whose shapes are
  // consistent with z = MatMul(x, y) under the given batch/transpose flags.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with random values drawn from Rand().
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random dimension size in [1, 10].
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};
868
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}
872
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
876
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}
880
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}
884
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}
888
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}
892
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}
896
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
900
901class NaryGradTest : public ::testing::Test {
902 protected:
903  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
904
905  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
906               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
907    TF_ASSERT_OK(scope_.status());
908    float max_error;
909    TF_ASSERT_OK(
910        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
911    EXPECT_LT(max_error, 1e-3);
912  }
913
914  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
915               const TensorShape& y_shape) {
916    TF_ASSERT_OK(scope_.status());
917    float max_error;
918    TF_ASSERT_OK(
919        ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
920    EXPECT_LT(max_error, 1e-3);
921  }
922
923  Scope scope_;
924};
925
926TEST_F(NaryGradTest, AddN) {
927  TensorShape shape({3, 2, 5});
928  std::vector<Output> xs;
929  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
930  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
931  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
932  auto y = AddN(scope_, xs);
933  RunTest(xs, {shape, shape, shape}, {y}, {shape});
934}
935
936TEST_F(NaryGradTest, Add) {
937  TensorShape x1_shape({3, 2, 5});
938  TensorShape x2_shape({2, 5});
939  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
940  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
941  auto y = Add(scope_, x1, x2);
942  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
943}
944
945TEST_F(NaryGradTest, Sub) {
946  TensorShape x1_shape({3, 2, 5});
947  TensorShape x2_shape({2, 5});
948  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
949  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
950  auto y = Sub(scope_, x1, x2);
951  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
952}
953
954TEST_F(NaryGradTest, Mul) {
955  TensorShape x1_shape({3, 2, 5});
956  TensorShape x2_shape({2, 5});
957  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
958  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
959  auto y = Mul(scope_, x1, x2);
960  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
961}
962
963TEST_F(NaryGradTest, Div) {
964  TensorShape x_shape({3, 2, 5});
965  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
966  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
967  // division errors in the numeric estimator used by the gradient checker.
968  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
969  RunTest({x}, {x_shape}, {y}, {x_shape});
970}
971
972TEST_F(NaryGradTest, RealDiv) {
973  TensorShape x_shape({3, 2, 5});
974  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
975  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
976  // division errors in the numeric estimator used by the gradient checker.
977  auto y =
978      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
979  RunTest({x}, {x_shape}, {y}, {x_shape});
980}
981
982TEST_F(NaryGradTest, SquaredDifference) {
983  TensorShape x1_shape({3, 2, 5});
984  TensorShape x2_shape({2, 5});
985  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
986  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
987  auto y = SquaredDifference(scope_, x1, x2);
988  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
989}
990
991TEST_F(NaryGradTest, Maximum) {
992  TensorShape shape({3, 2});
993  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
994  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
995  // Select values away from 1.0f to avoid instability when computing
996  // finite differences.
997  Tensor x_init_value =
998      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
999  RunTest(x, x_init_value, y, shape);
1000}
1001
1002TEST_F(NaryGradTest, Minimum) {
1003  TensorShape shape({3, 2});
1004  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1005  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
1006  // Select values away from 1.0f to avoid instability when computing
1007  // finite differences.
1008  Tensor x_init_value =
1009      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
1010  RunTest(x, x_init_value, y, shape);
1011}
1012
1013}  // namespace
1014}  // namespace tensorflow
1015