math_grad_test.cc revision 1fa73c53ab95693f070ce70e6be0c644d83c163a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.

class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    TANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

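  // Runs a unary-op gradient test: fills a {2, 3, 2} input tensor x using
  // x_fn, an incoming gradient dy using dy_fn, and the expected gradient dx
  // using dx_fn; then builds y = op(x), invokes the registered gradient
  // function for the op, and checks the computed gradient against dx.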
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

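  // RV and CRV return an element of 'v' chosen uniformly at random.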
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  const complex64 one_{1.0, 0};

  Scope scope_;
};

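// Each test below builds the expected gradient from the analytic derivative
// of the op. For Abs, d|x|/dx = sign(x), which equals x itself on the sampled
// values {-1, 0, 1}, so the expected gradient is x * dy.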
TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return -dy; };
  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return -(1 / (x * x)) * dy;
  };
  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
}

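// In the _Complex tests below, the expected gradient conjugates the analytic
// derivative, i.e. dx = dy * conj(f'(x)).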
TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64 x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64 x, const complex64 dy) {
    return -conjugate(one_ / (x * x)) * dy;
  };
  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(2, 0) * x) * dy;
  };
  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * 0.5 * (1.0 / std::sqrt(x));
  };
  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
  };
  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
  };
  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
  };
  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::exp(x);
  };
  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::exp(x));
  };
  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(one_ / x);
  };
  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / (1.0 + x));
  };
  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + conjugate(x));
  };
  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}

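// d(tanh(x))/dx = 1 - tanh(x)^2.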
TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = std::tanh(x);
    return dy * (1.0 - y * y);
  };
  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = std::tanh(x);
    return dy * conjugate((one_ - y * y));
  };
  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
}

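// d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)).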
TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float y = 1.0 / (1.0 + std::exp(-x));
    return dy * y * (1.0 - y);
  };
  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 y = one_ / (one_ + std::exp(-x));
    return dy * conjugate(y * (one_ - y));
  };
  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * std::cos(x);
  };
  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(std::cos(x));
  };
  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * -1.0 * std::sin(x);
  };
  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy * conjugate(-std::sin(x));
  };
  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (-1.0 / std::sqrt(1.0 - x * x));
  };
  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / -conjugate(std::sqrt(one_ - x * x));
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    const float cosx = std::cos(x);
    return dy * (1 / (cosx * cosx));
  };
  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    const complex64 cosx = std::cos(x);
    return dy / conjugate(cosx * cosx);
  };
  // TODO(kbsriram)
  // Enable when tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
  auto dx_fn = [this](const float x, const float dy) {
    return dy * (1 / (1 + x * x));
  };
  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  auto dy_fn = [this](const complex64& x) {
    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
  };
  auto dx_fn = [this](const complex64& x, const complex64& dy) {
    return dy / (one_ + x * x);
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
  }
}

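// Tests gradients of the Real, Imag, and Conj ops against hand-computed
// expected tensors.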
class CWiseUnaryComplexGradTest : public ::testing::Test {
 protected:
  CWiseUnaryComplexGradTest()
      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType { REAL, IMAG, CONJ };

  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
                            const Tensor& dy, const Tensor& dx_expected) {
    Output y;
    switch (op_type) {
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
    }

    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor dx;
    test::GetTensor(scope_, grad_outputs[0], &dx);
    test::ExpectClose(dx, dx_expected);
  }

  Scope scope_;
};

TEST_F(CWiseUnaryComplexGradTest, Real) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
  TestCWiseGradComplex(REAL, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Imag) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}

TEST_F(CWiseUnaryComplexGradTest, Conj) {
  Tensor x = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dy = test::AsTensor<complex64>(
      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
  Tensor dx_expected = test::AsTensor<complex64>(
      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
}

class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

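  // Verifies the registered (Batch)MatMul gradient against the closed-form
  // expressions; e.g. for z = MatMul(x, y) with no transposes:
  //   dx = MatMul(dz, y, transpose_b=true)
  //   dy = MatMul(x, dz, transpose_a=true)
  // The other transpose combinations are checked analogously below.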
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    if (!t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

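  // Generates random x, y, and dz tensors whose shapes are consistent with
  // z = MatMul(x, y) under the given transpose flags, with a leading batch
  // dimension when 'is_batch' is true.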
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}

}  // namespace
}  // namespace tensorflow