// math_grad_test.cc revision a373b1f74215e44920bf9362a51bece530edf88a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.

34class CWiseUnaryGradTest : public ::testing::Test {
35 protected:
36  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
37
38  enum UnaryOpType {
39    ABS,
40    NEG,
41    INV,
42    SQUARE,
43    SQRT,
44    RSQRT,
45    EXP,
46    EXPM1,
47    LOG,
48    LOG1P,
49    SINH,
50    COSH,
51    TANH,
52    ASINH,
53    ACOSH,
54    ATANH,
55    SIGMOID,
56    SIGN,
57    SIN,
58    COS,
59    ASIN,
60    ACOS,
61    TAN,
62    ATAN,
63    REAL,
64    IMAG,
65    CONJ,
66    COMPLEX,
67    ANGLE
68  };
69
70  template <typename X_T, typename Y_T>
71  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
72    TF_ASSERT_OK(scope_.status());
73    DataType x_type = DataTypeToEnum<X_T>::v();
74    TensorShape shape({2, 3, 2});
75    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
76    Tensor x_data(x_type, shape);
77    auto x_data_flat = x_data.flat<X_T>();
78    for (int i = 0; i < x_data_flat.size(); ++i) {
79      x_data_flat(i) = x_fn(i);
80    }
81
82    Output y;
83    switch (op_type) {
84      case ABS:
85        y = Abs(scope_, x);
86        break;
87      case NEG:
88        y = Neg(scope_, x);
89        break;
90      case INV:
91        y = Reciprocal(scope_, x);
92        break;
93      case SQUARE:
94        y = Square(scope_, x);
95        break;
96      case SQRT:
97        y = Sqrt(scope_, x);
98        break;
99      case RSQRT:
100        y = Rsqrt(scope_, x);
101        break;
102      case EXP:
103        y = Exp(scope_, x);
104        break;
105      case EXPM1:
106        y = Expm1(scope_, x);
107        break;
108      case LOG:
109        y = Log(scope_, x);
110        break;
111      case LOG1P:
112        y = Log1p(scope_, x);
113        break;
114      case SINH:
115        y = Sinh(scope_, x);
116        break;
117      case COSH:
118        y = Cosh(scope_, x);
119        break;
120      case TANH:
121        y = Tanh(scope_, x);
122        break;
123      case ASINH:
124        y = Asinh(scope_, x);
125        break;
126      case ACOSH:
127        y = Acosh(scope_, x);
128        break;
129      case ATANH:
130        y = Atanh(scope_, x);
131        break;
132      case SIGMOID:
133        y = Sigmoid(scope_, x);
134        break;
135      case SIGN:
136        y = Sign(scope_, x);
137        break;
138      case SIN:
139        y = Sin(scope_, x);
140        break;
141      case COS:
142        y = Cos(scope_, x);
143        break;
144      case ASIN:
145        y = Asin(scope_, x);
146        break;
147      case ACOS:
148        y = Acos(scope_, x);
149        break;
150      case TAN:
151        y = Tan(scope_, x);
152        break;
153      case ATAN:
154        y = Atan(scope_, x);
155        break;
156      case REAL:
157        y = Real(scope_, x);
158        break;
159      case IMAG:
160        y = Imag(scope_, x);
161        break;
162      case CONJ:
163        y = Conj(scope_, x);
164        break;
165      case COMPLEX:
166        y = Complex(scope_, x, x);
167        break;
168      case ANGLE:
169        y = Angle(scope_, x);
170        break;
171    }
172
173    float max_error;
174    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
175                                                        shape, &max_error)));
176    EXPECT_LT(max_error, 1e-3f);
177  }
178
179  float RV(const std::vector<float>& v) {
180    return v[random::New64() % v.size()];
181  }
182
183  complex64 CRV(const std::vector<complex64>& v) {
184    return v[random::New64() % v.size()];
185  }
186
187  complex64 conjugate(const complex64& val) {
188    return complex64(val.real(), -val.imag());
189  }
190
191  Scope scope_;
192};
193
194TEST_F(CWiseUnaryGradTest, Abs) {
195  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
196  TestCWiseGrad<float, float>(ABS, x_fn);
197}
198
199TEST_F(CWiseUnaryGradTest, Neg) {
200  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
201  TestCWiseGrad<float, float>(NEG, x_fn);
202}
203
204TEST_F(CWiseUnaryGradTest, Reciprocal) {
205  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
206  TestCWiseGrad<float, float>(INV, x_fn);
207}
208
209TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
210  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
211  TestCWiseGrad<complex64, complex64>(INV, x_fn);
212}
213
214TEST_F(CWiseUnaryGradTest, Square) {
215  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
216  TestCWiseGrad<float, float>(SQUARE, x_fn);
217}
218
219TEST_F(CWiseUnaryGradTest, Square_Complex) {
220  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
221  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
222}
223
224TEST_F(CWiseUnaryGradTest, Sqrt) {
225  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
226  TestCWiseGrad<float, float>(SQRT, x_fn);
227}
228
229TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
230  auto x_fn = [this](const int i) {
231    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
232  };
233  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
234}
235
236TEST_F(CWiseUnaryGradTest, Rsqrt) {
237  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
238  TestCWiseGrad<float, float>(RSQRT, x_fn);
239}
240
241TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
242  auto x_fn = [this](const int i) {
243    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
244  };
245  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
246}
247
248TEST_F(CWiseUnaryGradTest, Exp) {
249  auto x_fn = [this](const int i) {
250    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
251  };
252  TestCWiseGrad<float, float>(EXP, x_fn);
253}
254
255TEST_F(CWiseUnaryGradTest, Exp_Complex) {
256  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
257  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
258}
259
260TEST_F(CWiseUnaryGradTest, Expm1) {
261  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
262  TestCWiseGrad<float, float>(EXPM1, x_fn);
263}
264
265TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
266  auto x_fn = [this](const int i) {
267    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
268  };
269  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
270}
271
272TEST_F(CWiseUnaryGradTest, Log) {
273  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
274  TestCWiseGrad<float, float>(LOG, x_fn);
275}
276
277TEST_F(CWiseUnaryGradTest, Log_Complex) {
278  auto x_fn = [this](const int i) {
279    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
280  };
281  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
282}
283
284TEST_F(CWiseUnaryGradTest, Log1p) {
285  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
286  TestCWiseGrad<float, float>(LOG1P, x_fn);
287}
288
289TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
290  auto x_fn = [this](const int i) {
291    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
292  };
293  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
294}
295
296TEST_F(CWiseUnaryGradTest, Sinh) {
297  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
298  TestCWiseGrad<float, float>(SINH, x_fn);
299}
300
301TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
302  auto x_fn = [this](const int i) {
303    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
304  };
305  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
306}
307
308TEST_F(CWiseUnaryGradTest, Cosh) {
309  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
310  TestCWiseGrad<float, float>(COSH, x_fn);
311}
312
313TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
314  auto x_fn = [this](const int i) {
315    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
316  };
317  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
318}
319
320TEST_F(CWiseUnaryGradTest, Tanh) {
321  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
322  TestCWiseGrad<float, float>(TANH, x_fn);
323}
324
325TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
326  auto x_fn = [this](const int i) {
327    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
328  };
329  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
330}
331
332TEST_F(CWiseUnaryGradTest, Asinh) {
333  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
334  TestCWiseGrad<float, float>(ASINH, x_fn);
335}
336
337TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
338  auto x_fn = [this](const int i) {
339    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
340  };
341  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
342}
343
344TEST_F(CWiseUnaryGradTest, Acosh) {
345  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
346  TestCWiseGrad<float, float>(ACOSH, x_fn);
347}
348
349TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
350  auto x_fn = [this](const int i) {
351    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
352  };
353  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
354}
355
356TEST_F(CWiseUnaryGradTest, Atanh) {
357  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
358  TestCWiseGrad<float, float>(ATANH, x_fn);
359}
360
361TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
362  auto x_fn = [this](const int i) {
363    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
364  };
365  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
366}
367
368TEST_F(CWiseUnaryGradTest, Sigmoid) {
369  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
370  TestCWiseGrad<float, float>(SIGMOID, x_fn);
371}
372
373TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
374  auto x_fn = [this](const int i) {
375    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
376  };
377  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
378}
379
380TEST_F(CWiseUnaryGradTest, Sign) {
381  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
382  TestCWiseGrad<float, float>(SIGN, x_fn);
383}
384
385TEST_F(CWiseUnaryGradTest, Sin) {
386  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
387  TestCWiseGrad<float, float>(SIN, x_fn);
388}
389
390TEST_F(CWiseUnaryGradTest, Sin_Complex) {
391  auto x_fn = [this](const int i) {
392    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
393  };
394  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
395}
396
397TEST_F(CWiseUnaryGradTest, Cos) {
398  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
399  TestCWiseGrad<float, float>(COS, x_fn);
400}
401
402TEST_F(CWiseUnaryGradTest, Cos_Complex) {
403  auto x_fn = [this](const int i) {
404    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
405  };
406  TestCWiseGrad<complex64, complex64>(COS, x_fn);
407}
408
409TEST_F(CWiseUnaryGradTest, Asin) {
410  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
411  TestCWiseGrad<float, float>(ASIN, x_fn);
412}
413
414TEST_F(CWiseUnaryGradTest, Asin_Complex) {
415  auto x_fn = [this](const int i) {
416    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
417  };
418  // TODO(kbsriram)
419  // Enable test when the asin kernel supports complex numbers
420  if (false) {
421    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
422  }
423}
424
425TEST_F(CWiseUnaryGradTest, Acos) {
426  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
427  TestCWiseGrad<float, float>(ACOS, x_fn);
428}
429
430TEST_F(CWiseUnaryGradTest, Acos_Complex) {
431  auto x_fn = [this](const int i) {
432    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
433  };
434  // TODO(kbsriram)
435  // Add test when the acos kernel supports complex numbers
436  if (false) {
437    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
438  }
439}
440
441TEST_F(CWiseUnaryGradTest, Tan) {
442  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
443  TestCWiseGrad<float, float>(TAN, x_fn);
444}
445
446TEST_F(CWiseUnaryGradTest, Tan_Complex) {
447  auto x_fn = [this](const int i) {
448    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
449  };
450  // TODO(kbsriram)
451  // Enable when tan kernel supports complex inputs
452  if (false) {
453    TestCWiseGrad<complex64, complex64>(TAN, x_fn);
454  }
455}
456
457TEST_F(CWiseUnaryGradTest, Atan) {
458  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
459  TestCWiseGrad<float, float>(ATAN, x_fn);
460}
461
462TEST_F(CWiseUnaryGradTest, Atan_Complex) {
463  auto x_fn = [this](const int i) {
464    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
465  };
466  // TODO(kbsriram)
467  // Add test when the atan kernel supports complex numbers
468  if (false) {
469    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
470  }
471}
472
473TEST_F(CWiseUnaryGradTest, Real) {
474  auto x_fn = [this](const int i) {
475    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
476  };
477  TestCWiseGrad<complex64, float>(REAL, x_fn);
478}
479
480TEST_F(CWiseUnaryGradTest, Imag) {
481  auto x_fn = [this](const int i) {
482    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
483  };
484  TestCWiseGrad<complex64, float>(IMAG, x_fn);
485}
486
487TEST_F(CWiseUnaryGradTest, Conj) {
488  auto x_fn = [this](const int i) {
489    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
490  };
491  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
492}
493
494TEST_F(CWiseUnaryGradTest, Complex) {
495  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
496  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
497}
498
499TEST_F(CWiseUnaryGradTest, Angle) {
500  auto x_fn = [this](const int i) {
501    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
502  };
503  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
504}
505
506class MathGradTest : public ::testing::Test {
507 protected:
508  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
509
510  template <typename T>
511  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
512    TF_ASSERT_OK(root_.status());
513    // Generate random (but compatible) shapes for matrix multiplication.
514    std::vector<TensorShape> shapes;
515    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
516    TensorShape x_shape = shapes[0];
517    TensorShape y_shape = shapes[1];
518    TensorShape z_shape = shapes[2];
519    auto x =
520        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
521    auto y =
522        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
523    Output z;
524    if (is_batch) {
525      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
526    } else {
527      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
528    }
529
530    float max_error;
531    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
532        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
533    EXPECT_LT(max_error, 1e-3);
534  }
535
536  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
537                        std::vector<TensorShape>* shapes) {
538    // Choose a random batch size in [1, 4]
539    const int b = 1 + (random::New64() % 4);
540    // z = MatMul(x, y)
541    const int m = Rand();
542    const int k = Rand();
543    const int n = Rand();
544
545    TensorShape x_shape;
546    if (is_batch) {
547      // x.shape = [b, m, k]
548      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
549    } else {
550      // x.shape = [m, k]
551      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
552    }
553    shapes->push_back(x_shape);
554
555    TensorShape y_shape;
556    if (is_batch) {
557      // y.shape = [b, k, n]
558      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
559    } else {
560      // y.shape = [k, n]
561      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
562    }
563    shapes->push_back(y_shape);
564
565    TensorShape z_shape;
566    if (is_batch) {
567      // z.shape = [b, m, n]
568      z_shape = TensorShape({b, m, n});
569    } else {
570      // z.shape = [m, n]
571      z_shape = TensorShape({m, n});
572    }
573    shapes->push_back(z_shape);
574  }
575
576  int Rand() { return 1 + (random::New64() % 10); }
577
578  Scope root_;
579};
580
581TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
582  TestMatMulGrad<float>(false, false, false);
583}
584
585TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
586  TestMatMulGrad<complex64>(false, false, false);
587}
588
589TEST_F(MathGradTest, MatMulGrad_TransposeX) {
590  TestMatMulGrad<float>(false, true, false);
591}
592
593TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
594  TestMatMulGrad<complex64>(false, true, false);
595}
596
597TEST_F(MathGradTest, MatMulGrad_TransposeY) {
598  TestMatMulGrad<float>(false, false, true);
599}
600
601TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
602  TestMatMulGrad<complex64>(false, false, true);
603}
604
605TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
606  TestMatMulGrad<float>(false, true, true);
607}
608
609TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
610  TestMatMulGrad<complex64>(false, true, true);
611}
612
613TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
614  TestMatMulGrad<float>(true, false, false);
615}
616
617TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
618  TestMatMulGrad<complex64>(true, false, false);
619}
620
621TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
622  TestMatMulGrad<float>(true, true, false);
623}
624
625TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
626  TestMatMulGrad<complex64>(true, true, false);
627}
628
629TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
630  TestMatMulGrad<float>(true, false, true);
631}
632
633TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
634  TestMatMulGrad<complex64>(true, false, true);
635}
636
637TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
638  TestMatMulGrad<float>(true, true, true);
639}
640
641TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
642  TestMatMulGrad<complex64>(true, true, true);
643}
644
645class NaryGradTest : public ::testing::Test {
646 protected:
647  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
648
649  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
650               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
651    TF_ASSERT_OK(scope_.status());
652    float max_error;
653    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
654        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
655    EXPECT_LT(max_error, 1e-3);
656  }
657
658  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
659               const TensorShape& y_shape) {
660    TF_ASSERT_OK(scope_.status());
661    float max_error;
662    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
663        scope_, x, x_init_value, y, y_shape, &max_error)));
664    EXPECT_LT(max_error, 1e-3);
665  }
666
667  Scope scope_;
668};
669
670TEST_F(NaryGradTest, Sum) {
671  TensorShape x_shape({2, 3, 5, 7});
672  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
673  auto y = Sum(scope_, x, {1, -1});
674  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
675  TensorShape y_shape({2, 5});
676  RunTest({x}, {x_shape}, {y}, {y_shape});
677}
678
679TEST_F(NaryGradTest, Mean) {
680  TensorShape x_shape({2, 3, 5, 7});
681  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
682  auto y = Mean(scope_, x, {1, -1});
683  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
684  TensorShape y_shape({2, 5});
685  RunTest({x}, {x_shape}, {y}, {y_shape});
686}
687
688TEST_F(NaryGradTest, Min) {
689  TensorShape x_shape({2, 3});
690  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
691  auto y = Min(scope_, x, {-1});
692  // y's shape is the result of reducing x along axes -1 (= 1)
693  TensorShape y_shape({2});
694  Tensor x_init_value =
695      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
696  RunTest(x, x_init_value, y, y_shape);
697}
698
699TEST_F(NaryGradTest, Max) {
700  TensorShape x_shape({2, 3});
701  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
702  auto y = Max(scope_, x, {-1});
703  // y's shape is the result of reducing x along axes -1 (= 1)
704  TensorShape y_shape({2});
705  Tensor x_init_value =
706      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
707  RunTest(x, x_init_value, y, y_shape);
708}
709
710TEST_F(NaryGradTest, MinMulti) {
711  // Test gradient when there are multiple minima.
712  // Note that we cannot directly use a test Tensor with multiple
713  // minima, as the numeric estimator will calculate incorrect
714  // gradients when perturbing each entry in the Tensor (which then
715  // changes how many minima exist.)
716  // Instead, we use a single input that broadcast-multiplies a larger
717  // tensor with equal values, and apply reduce_min to the multiplied
718  // result.
719  TensorShape x_shape({1});
720  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
721  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
722  auto y = Min(scope_, all_same, {0});
723  // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
724  TensorShape y_shape({1});
725  RunTest({x}, {x_shape}, {y}, {y_shape});
726}
727
728TEST_F(NaryGradTest, MaxMulti) {
729  TensorShape x_shape({1});
730  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
731  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
732  auto y = Max(scope_, all_same, {0});
733  TensorShape y_shape({1});
734  RunTest({x}, {x_shape}, {y}, {y_shape});
735}
736
737TEST_F(NaryGradTest, AddN) {
738  TensorShape shape({3, 2, 5});
739  std::vector<Output> xs;
740  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
741  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
742  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
743  auto y = AddN(scope_, xs);
744  RunTest(xs, {shape, shape, shape}, {y}, {shape});
745}
746
747TEST_F(NaryGradTest, Add) {
748  TensorShape x1_shape({3, 2, 5});
749  TensorShape x2_shape({2, 5});
750  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
751  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
752  auto y = Add(scope_, x1, x2);
753  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
754}
755
756TEST_F(NaryGradTest, Sub) {
757  TensorShape x1_shape({3, 2, 5});
758  TensorShape x2_shape({2, 5});
759  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
760  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
761  auto y = Sub(scope_, x1, x2);
762  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
763}
764
765TEST_F(NaryGradTest, Mul) {
766  TensorShape x1_shape({3, 2, 5});
767  TensorShape x2_shape({2, 5});
768  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
769  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
770  auto y = Mul(scope_, x1, x2);
771  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
772}
773
774TEST_F(NaryGradTest, Div) {
775  TensorShape x_shape({3, 2, 5});
776  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
777  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
778  // division errors in the numeric estimator used by the gradient checker.
779  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
780  RunTest({x}, {x_shape}, {y}, {x_shape});
781}
782
783TEST_F(NaryGradTest, RealDiv) {
784  TensorShape x_shape({3, 2, 5});
785  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
786  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
787  // division errors in the numeric estimator used by the gradient checker.
788  auto y =
789      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
790  RunTest({x}, {x_shape}, {y}, {x_shape});
791}
792
793TEST_F(NaryGradTest, SquaredDifference) {
794  TensorShape x1_shape({3, 2, 5});
795  TensorShape x2_shape({2, 5});
796  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
797  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
798  auto y = SquaredDifference(scope_, x1, x2);
799  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
800}
801
802TEST_F(NaryGradTest, Maximum) {
803  TensorShape shape({3, 2});
804  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
805  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
806  // Select values away from 1.0f to avoid instability when computing
807  // finite differences.
808  Tensor x_init_value =
809      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
810  RunTest(x, x_init_value, y, shape);
811}
812
813TEST_F(NaryGradTest, Minimum) {
814  TensorShape shape({3, 2});
815  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
816  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
817  // Select values away from 1.0f to avoid instability when computing
818  // finite differences.
819  Tensor x_init_value =
820      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
821  RunTest(x, x_init_value, y, shape);
822}
823
824TEST_F(NaryGradTest, Lgamma) {
825  TensorShape shape({3, 2});
826  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
827  auto y = Lgamma(scope_, x);
828  // Select values to avoid instability when computing finite differences.
829  // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg
830  Tensor x_init_value =
831      test::AsTensor<float>({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2});
832  RunTest(x, x_init_value, y, shape);
833  // TODO(suharshs): add test case for complex values
834}
835
}  // namespace
}  // namespace tensorflow
