// math_grad_test.cc revision 50b999a8336d19400ab75aea66fe46eca2f5fe0b
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <cmath>
#include <complex>
#include <functional>
#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
23
24namespace tensorflow {
25using namespace ops;  // NOLINT(build/namespaces)
26
27namespace {
28
29// TODO(andydavis) Test gradient function against numeric gradients output.
30// TODO(andydavis) As more gradients are added move common test functions
31// to a testutil library.
32
33class CWiseUnaryGradTest : public ::testing::Test {
34 protected:
35  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
36
37  enum UnaryOpType {
38    ABS,
39    NEG,
40    INV,
41    SQUARE,
42    SQRT,
43    RSQRT,
44    EXP,
45    EXPM1,
46    LOG,
47    LOG1P,
48    SINH,
49    COSH,
50    TANH,
51    SIGMOID,
52    SIGN,
53    SIN,
54    COS,
55    ASIN,
56    ACOS,
57    TAN,
58    ATAN
59  };
60
61  template <typename T>
62  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
63                     const std::function<T(const T&)>& dy_fn,
64                     const std::function<T(const T&, const T&)>& dx_fn) {
65    DataType dtype = DataTypeToEnum<T>::v();
66    Tensor x(dtype, {2, 3, 2});
67    auto x_flat = x.flat<T>();
68    for (int i = 0; i < x_flat.size(); ++i) {
69      x_flat(i) = x_fn(i);
70    }
71
72    Tensor dy(dtype, {2, 3, 2});
73    auto dy_flat = dy.flat<T>();
74    for (int i = 0; i < dy_flat.size(); ++i) {
75      dy_flat(i) = dy_fn(x_flat(i));
76    }
77
78    Tensor dx(dtype, {2, 3, 2});
79    auto dx_flat = dx.flat<T>();
80    for (int i = 0; i < dx_flat.size(); ++i) {
81      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
82    }
83
84    Output y;
85    switch (op_type) {
86      case ABS:
87        y = Abs(scope_, x);
88        break;
89      case NEG:
90        y = Neg(scope_, x);
91        break;
92      case INV:
93        y = Reciprocal(scope_, x);
94        break;
95      case SQUARE:
96        y = Square(scope_, x);
97        break;
98      case SQRT:
99        y = Sqrt(scope_, x);
100        break;
101      case RSQRT:
102        y = Rsqrt(scope_, x);
103        break;
104      case EXP:
105        y = Exp(scope_, x);
106        break;
107      case EXPM1:
108        y = Expm1(scope_, x);
109        break;
110      case LOG:
111        y = Log(scope_, x);
112        break;
113      case LOG1P:
114        y = Log1p(scope_, x);
115        break;
116      case SINH:
117        y = Sinh(scope_, x);
118        break;
119      case COSH:
120        y = Cosh(scope_, x);
121        break;
122      case TANH:
123        y = Tanh(scope_, x);
124        break;
125      case SIGMOID:
126        y = Sigmoid(scope_, x);
127        break;
128      case SIGN:
129        y = Sign(scope_, x);
130        break;
131      case SIN:
132        y = Sin(scope_, x);
133        break;
134      case COS:
135        y = Cos(scope_, x);
136        break;
137      case ASIN:
138        y = Asin(scope_, x);
139        break;
140      case ACOS:
141        y = Acos(scope_, x);
142        break;
143      case TAN:
144        y = Tan(scope_, x);
145        break;
146      case ATAN:
147        y = Atan(scope_, x);
148        break;
149    }
150
151    std::vector<Output> grad_outputs;
152    TF_ASSERT_OK(test::CallGradFunction(
153        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
154    Tensor output;
155    test::GetTensor(scope_, grad_outputs[0], &output);
156    test::ExpectClose(output, dx);
157  }
158
159  float RV(const std::vector<float>& v) {
160    return v[random::New64() % v.size()];
161  }
162
163  complex64 CRV(const std::vector<complex64>& v) {
164    return v[random::New64() % v.size()];
165  }
166
167  complex64 conjugate(const complex64& val) {
168    return complex64(val.real(), -val.imag());
169  }
170
171  const complex64 one_{1.0, 0};
172
173  Scope scope_;
174};
175
176TEST_F(CWiseUnaryGradTest, Abs) {
177  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
178  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
179  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
180  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
181}
182
183TEST_F(CWiseUnaryGradTest, Neg) {
184  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
185  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
186  auto dx_fn = [this](const float x, const float dy) { return -dy; };
187  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
188}
189
190TEST_F(CWiseUnaryGradTest, Reciprocal) {
191  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
192  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
193  auto dx_fn = [this](const float x, const float dy) {
194    return -(1 / (x * x)) * dy;
195  };
196  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
197}
198
199TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
200  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
201  auto dy_fn = [this](const complex64 x) {
202    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
203  };
204  auto dx_fn = [this](const complex64 x, const complex64 dy) {
205    return -conjugate(one_ / (x * x)) * dy;
206  };
207  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
208}
209
210TEST_F(CWiseUnaryGradTest, Square) {
211  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
212  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
213  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
214  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
215}
216
217TEST_F(CWiseUnaryGradTest, Square_Complex) {
218  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
219  auto dy_fn = [this](const complex64& x) {
220    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
221  };
222  auto dx_fn = [this](const complex64& x, const complex64& dy) {
223    return conjugate(complex64(2, 0) * x) * dy;
224  };
225  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
226}
227
228TEST_F(CWiseUnaryGradTest, Sqrt) {
229  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
230  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
231  auto dx_fn = [this](const float x, const float dy) {
232    return dy * 0.5 * (1.0 / std::sqrt(x));
233  };
234  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
235}
236
237TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
238  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
239  auto dy_fn = [this](const complex64& x) {
240    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
241  };
242  auto dx_fn = [this](const complex64& x, const complex64& dy) {
243    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
244  };
245  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
246}
247
248TEST_F(CWiseUnaryGradTest, Rsqrt) {
249  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
250  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
251  auto dx_fn = [this](const float x, const float dy) {
252    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
253  };
254  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
255}
256
257TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
258  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
259  auto dy_fn = [this](const complex64& x) {
260    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
261  };
262  auto dx_fn = [this](const complex64& x, const complex64& dy) {
263    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
264  };
265  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
266}
267
268TEST_F(CWiseUnaryGradTest, Exp) {
269  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
270  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
271  auto dx_fn = [this](const float x, const float dy) {
272    return dy * std::exp(x);
273  };
274  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
275}
276
277TEST_F(CWiseUnaryGradTest, Exp_Complex) {
278  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
279  auto dy_fn = [this](const complex64& x) {
280    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
281  };
282  auto dx_fn = [this](const complex64& x, const complex64& dy) {
283    return dy * conjugate(std::exp(x));
284  };
285  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
286}
287
288TEST_F(CWiseUnaryGradTest, Expm1) {
289  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
290  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
291  auto dx_fn = [this](const float x, const float dy) {
292    return dy * std::exp(x);
293  };
294  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
295}
296
297TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
298  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
299  auto dy_fn = [this](const complex64& x) {
300    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
301  };
302  auto dx_fn = [this](const complex64& x, const complex64& dy) {
303    return dy * conjugate(std::exp(x));
304  };
305  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
306}
307
308TEST_F(CWiseUnaryGradTest, Log) {
309  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
310  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
311  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
312  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
313}
314
315TEST_F(CWiseUnaryGradTest, Log_Complex) {
316  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
317  auto dy_fn = [this](const complex64& x) {
318    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
319  };
320  auto dx_fn = [this](const complex64& x, const complex64& dy) {
321    return dy * conjugate(one_ / x);
322  };
323  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
324}
325
326TEST_F(CWiseUnaryGradTest, Log1p) {
327  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
328  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
329  auto dx_fn = [this](const float x, const float dy) {
330    return dy * (1.0 / (1.0 + x));
331  };
332  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
333}
334
335TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
336  auto x_fn = [this](const int i) {
337    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
338  };
339  auto dy_fn = [this](const complex64& x) {
340    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
341  };
342  auto dx_fn = [this](const complex64& x, const complex64& dy) {
343    return dy / (one_ + conjugate(x));
344  };
345  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
346}
347
348TEST_F(CWiseUnaryGradTest, Sinh) {
349  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
350  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
351  auto dx_fn = [this](const float x, const float dy) {
352    return dy * std::cosh(x);
353  };
354  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
355}
356
357TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
358  auto x_fn = [this](const int i) {
359    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
360  };
361  auto dy_fn = [this](const complex64& x) {
362    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
363  };
364  auto dx_fn = [this](const complex64& x, const complex64& dy) {
365    return dy * conjugate(std::cosh(x));
366  };
367  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
368}
369
370TEST_F(CWiseUnaryGradTest, Cosh) {
371  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
372  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
373  auto dx_fn = [this](const float x, const float dy) {
374    return dy * std::sinh(x);
375  };
376  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
377}
378
379TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
380  auto x_fn = [this](const int i) {
381    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
382  };
383  auto dy_fn = [this](const complex64& x) {
384    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
385  };
386  auto dx_fn = [this](const complex64& x, const complex64& dy) {
387    return dy * conjugate(std::sinh(x));
388  };
389  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
390}
391
392TEST_F(CWiseUnaryGradTest, Tanh) {
393  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
394  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
395  auto dx_fn = [this](const float x, const float dy) {
396    const float y = std::tanh(x);
397    return dy * (1.0 - y * y);
398  };
399  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
400}
401
402TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
403  auto x_fn = [this](const int i) {
404    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
405  };
406  auto dy_fn = [this](const complex64& x) {
407    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
408  };
409  auto dx_fn = [this](const complex64& x, const complex64& dy) {
410    const complex64 y = std::tanh(x);
411    return dy * conjugate((one_ - y * y));
412  };
413  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
414}
415
416TEST_F(CWiseUnaryGradTest, Sigmoid) {
417  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
418  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
419  auto dx_fn = [this](const float x, const float dy) {
420    const float y = 1.0 / (1.0 + std::exp(-x));
421    return dy * y * (1.0 - y);
422  };
423  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
424}
425
426TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
427  auto x_fn = [this](const int i) {
428    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
429  };
430  auto dy_fn = [this](const complex64& x) {
431    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
432  };
433  auto dx_fn = [this](const complex64& x, const complex64& dy) {
434    const complex64 y = one_ / (one_ + std::exp(-x));
435    return dy * conjugate(y * (one_ - y));
436  };
437  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
438}
439
440TEST_F(CWiseUnaryGradTest, Sign) {
441  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
442  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
443  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
444  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
445}
446
447TEST_F(CWiseUnaryGradTest, Sin) {
448  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
449  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
450  auto dx_fn = [this](const float x, const float dy) {
451    return dy * std::cos(x);
452  };
453  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
454}
455
456TEST_F(CWiseUnaryGradTest, Sin_Complex) {
457  auto x_fn = [this](const int i) {
458    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
459  };
460  auto dy_fn = [this](const complex64& x) {
461    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
462  };
463  auto dx_fn = [this](const complex64& x, const complex64& dy) {
464    return dy * conjugate(std::cos(x));
465  };
466  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
467}
468
469TEST_F(CWiseUnaryGradTest, Cos) {
470  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
471  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
472  auto dx_fn = [this](const float x, const float dy) {
473    return dy * -1.0 * std::sin(x);
474  };
475  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
476}
477
478TEST_F(CWiseUnaryGradTest, Cos_Complex) {
479  auto x_fn = [this](const int i) {
480    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
481  };
482  auto dy_fn = [this](const complex64& x) {
483    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
484  };
485  auto dx_fn = [this](const complex64& x, const complex64& dy) {
486    return dy * conjugate(-std::sin(x));
487  };
488  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
489}
490
491TEST_F(CWiseUnaryGradTest, Asin) {
492  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
493  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
494  auto dx_fn = [this](const float x, const float dy) {
495    return dy * (1.0 / std::sqrt(1.0 - x * x));
496  };
497  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
498}
499
500TEST_F(CWiseUnaryGradTest, Asin_Complex) {
501  auto x_fn = [this](const int i) {
502    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
503  };
504  auto dy_fn = [this](const complex64& x) {
505    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
506  };
507  auto dx_fn = [this](const complex64& x, const complex64& dy) {
508    return dy / conjugate(std::sqrt(one_ - x * x));
509  };
510  // TODO(kbsriram)
511  // Enable test when the asin kernel supports complex numbers
512  if (false) {
513    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
514  }
515}
516
517TEST_F(CWiseUnaryGradTest, Acos) {
518  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
519  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
520  auto dx_fn = [this](const float x, const float dy) {
521    return dy * (-1.0 / std::sqrt(1.0 - x * x));
522  };
523  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
524}
525
526TEST_F(CWiseUnaryGradTest, Acos_Complex) {
527  auto x_fn = [this](const int i) {
528    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
529  };
530  auto dy_fn = [this](const complex64& x) {
531    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
532  };
533  auto dx_fn = [this](const complex64& x, const complex64& dy) {
534    return dy / -conjugate(std::sqrt(one_ - x * x));
535  };
536  // TODO(kbsriram)
537  // Add test when the acos kernel supports complex numbers
538  if (false) {
539    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
540  }
541}
542
543TEST_F(CWiseUnaryGradTest, Tan) {
544  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
545  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
546  auto dx_fn = [this](const float x, const float dy) {
547    const float cosx = std::cos(x);
548    return dy * (1 / (cosx * cosx));
549  };
550  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
551}
552
553TEST_F(CWiseUnaryGradTest, Tan_Complex) {
554  auto x_fn = [this](const int i) {
555    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
556  };
557  auto dy_fn = [this](const complex64& x) {
558    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
559  };
560  auto dx_fn = [this](const complex64& x, const complex64& dy) {
561    const complex64 cosx = std::cos(x);
562    return dy / conjugate(cosx * cosx);
563  };
564  // TODO(kbsriram)
565  // Enable when tan kernel supports complex inputs
566  if (false) {
567    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
568  }
569}
570
571TEST_F(CWiseUnaryGradTest, Atan) {
572  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
573  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
574  auto dx_fn = [this](const float x, const float dy) {
575    return dy * (1 / (1 + x * x));
576  };
577  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
578}
579
580TEST_F(CWiseUnaryGradTest, Atan_Complex) {
581  auto x_fn = [this](const int i) {
582    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
583  };
584  auto dy_fn = [this](const complex64& x) {
585    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
586  };
587  auto dx_fn = [this](const complex64& x, const complex64& dy) {
588    return dy / (one_ + x * x);
589  };
590  // TODO(kbsriram)
591  // Add test when the atan kernel supports complex numbers
592  if (false) {
593    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
594  }
595}
596
597class CWiseUnaryComplexGradTest : public ::testing::Test {
598 protected:
599  CWiseUnaryComplexGradTest()
600      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
601
602  enum UnaryOpType { REAL, IMAG, CONJ };
603
604  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
605                            const Tensor& dy, const Tensor& dx_expected) {
606    Output y;
607    switch (op_type) {
608      case REAL:
609        y = Real(scope_, x);
610        break;
611      case IMAG:
612        y = Imag(scope_, x);
613        break;
614      case CONJ:
615        y = Conj(scope_, x);
616        break;
617    }
618
619    std::vector<Output> grad_outputs;
620    TF_ASSERT_OK(test::CallGradFunction(
621        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
622    Tensor dx;
623    test::GetTensor(scope_, grad_outputs[0], &dx);
624    test::ExpectClose(dx, dx_expected);
625  }
626
627  Scope scope_;
628};
629
630TEST_F(CWiseUnaryComplexGradTest, Real) {
631  Tensor x = test::AsTensor<complex64>(
632      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
633  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
634  Tensor dx_expected = test::AsTensor<complex64>(
635      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
636  TestCWiseGradComplex(REAL, x, dy, dx_expected);
637}
638
639TEST_F(CWiseUnaryComplexGradTest, Imag) {
640  Tensor x = test::AsTensor<complex64>(
641      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
642  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
643  Tensor dx_expected = test::AsTensor<complex64>(
644      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
645  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
646}
647
648TEST_F(CWiseUnaryComplexGradTest, Conj) {
649  Tensor x = test::AsTensor<complex64>(
650      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
651  Tensor dy = test::AsTensor<complex64>(
652      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
653  Tensor dx_expected = test::AsTensor<complex64>(
654      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
655  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
656}
657
// Fixture for MatMul / BatchMatMul gradient tests. Random shapes and values
// are generated, the registered gradient function is invoked, and its outputs
// are compared against explicitly-built MatMul expressions for each of the
// four transpose combinations.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Runs the gradient of z = MatMul(x, y) (with the given transpose flags)
  // and checks the returned dx/dy against the analytic forms, e.g. for the
  // no-transpose case: dx = dz * y^T and dy = x^T * dz.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    // Each branch encodes the analytic gradient for one transpose combo.
    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Builds the forward op, invokes its registered gradient function with
  // upstream gradient 'dz', and evaluates both gradient outputs into 'out'.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul) with the given transpose/adjoint
  // flags and returns the resulting tensor.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Appends three random tensors to 'data': x, y, and dz (shaped like z),
  // with shapes consistent with z = MatMul(x, y) under the given flags.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with small random values (each in [1, 10]).
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random dimension size in [1, 10].
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};
783
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}
787
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
791
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}
795
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}
799
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}
803
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}
807
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}
811
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
815
816}  // namespace
817}  // namespace tensorflow
818