/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace {

using ops::Abs;
using ops::Add;
using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
using ops::Greater;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
using ops::Mean;
using ops::Min;
using ops::Minimum;
using ops::Mul;
using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
using ops::Where3;

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added, move common test functions
// to a testutil library.

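// Test fixture for element-wise unary op gradients: each test builds a
// single-op graph y = op(x), seeds x with hand-picked values, and verifies
// the registered gradient against a numeric estimate.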
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

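  // Ops covered by these tests; each enumerator selects one case in
  // TestCWiseGrad below.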
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN,
    REAL,
    IMAG,
    CONJ,
    COMPLEX,
    ANGLE,
    LGAMMA,
    ERF
  };

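  // Builds y = op_type(x) on a fixed {2, 3, 2} input whose elements are
  // produced by x_fn, then checks the symbolic gradient against the numeric
  // estimate from the gradient checker.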
  template <typename X_T, typename Y_T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
    TF_ASSERT_OK(scope_.status());
    DataType x_type = DataTypeToEnum<X_T>::v();
    TensorShape shape({2, 3, 2});
    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
    Tensor x_data(x_type, shape);
    auto x_data_flat = x_data.flat<X_T>();
    for (int i = 0; i < x_data_flat.size(); ++i) {
      x_data_flat(i) = x_fn(i);
    }

    Output y;
    switch (op_type) {
      using namespace ops;  // NOLINT(build/namespaces)
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
      case COMPLEX:
        y = Complex(scope_, x, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case LGAMMA:
        y = Lgamma(scope_, x);
        break;
      case ERF:
        y = Erf(scope_, x);
        break;
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
                                                        shape, &max_error)));
    EXPECT_LT(max_error, 1e-3f);
  }

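  // Returns an element of v chosen uniformly at random.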
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

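  // Complex-valued counterpart of RV.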
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

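  // Returns the complex conjugate of val.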
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(ABS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(NEG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  TestCWiseGrad<float, float>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
  TestCWiseGrad<float, float>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  TestCWiseGrad<float, float>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) {
    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
  };
  TestCWiseGrad<float, float>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
  };
  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
  TestCWiseGrad<float, float>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  TestCWiseGrad<float, float>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
  TestCWiseGrad<float, float>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh) {
  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
  TestCWiseGrad<float, float>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  TestCWiseGrad<float, float>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
  TestCWiseGrad<float, float>(ASIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
  TestCWiseGrad<float, float>(ACOS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Enable test when the tan kernel supports complex inputs
  if (false) {
    TestCWiseGrad<complex64, complex64>(TAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(ATAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Real) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(REAL, x_fn);
}

TEST_F(CWiseUnaryGradTest, Imag) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(IMAG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Conj) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
}

TEST_F(CWiseUnaryGradTest, Complex) {
  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
}

TEST_F(CWiseUnaryGradTest, Angle) {
  auto x_fn = [this](const int i) {
    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
  };
  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma) {
  auto x_fn = [this](const int i) {
    return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
  };
  TestCWiseGrad<float, float>(LGAMMA, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
  };
  // TODO(kbsriram)
  // Add test when the lgamma kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Erf) {
  auto x_fn = [this](const int i) {
    return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
  };
  TestCWiseGrad<float, float>(ERF, x_fn);
}

TEST_F(CWiseUnaryGradTest, Erf_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
  };
  // TODO(kbsriram)
  // Add test when the erf kernel supports complex numbers
  if (false) {
    TestCWiseGrad<complex64, complex64>(ERF, x_fn);
  }
}

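// Test fixture for MatMul and BatchMatMul gradients across every
// transpose/adjoint combination, for both float and complex64 inputs.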
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

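  // Builds z = matmul(x, y) over randomly sized but compatible shapes
  // (batched when is_batch is true; t_x/t_y select the transpose or adjoint
  // attributes) and checks the gradient error against the 1e-3 bound.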
  template <typename T>
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    TF_ASSERT_OK(root_.status());
    // Generate random (but compatible) shapes for matrix multiplication.
    std::vector<TensorShape> shapes;
    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
    TensorShape x_shape = shapes[0];
    TensorShape y_shape = shapes[1];
    TensorShape z_shape = shapes[2];
    auto x =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
    auto y =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

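  // Appends compatible shapes for x, y, and z = matmul(x, y) to *shapes,
  // honoring the requested batching and input transposition.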
  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
                        std::vector<TensorShape>* shapes) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    shapes->push_back(x_shape);

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    shapes->push_back(y_shape);

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    shapes->push_back(z_shape);
  }

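  // Returns a random matrix dimension in [1, 10].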
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(false, false, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(false, true, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(false, false, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(false, true, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(true, true, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(true, true, true);
}

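// Test fixture for n-ary and reduction op gradients. Both RunTest overloads
// delegate to the gradient checker and enforce the same 1e-3 error bound.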
class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

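  // Checks gradients of ys w.r.t. xs; input values are chosen by the
  // gradient checker.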
  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

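  // Single input/output variant where the test supplies x's initial value,
  // e.g. to keep inputs away from non-differentiable points of the op.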
  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, x, x_init_value, y, y_shape, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (= 1)
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (= 1)
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, MinMulti) {
  // Test gradient when there are multiple minima.
  // Note that we cannot directly use a test Tensor with multiple
  // minima, as the numeric estimator will calculate incorrect
  // gradients when perturbing each entry in the Tensor (which then
  // changes how many minima exist).
  // Instead, we use a single input that broadcast-multiplies a larger
  // tensor with equal values, and apply reduce_min to the multiplied
  // result.
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // all_same is a [3] shaped tensor reduced along dimension 0, so y is
  // [1] shaped
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Pow) {
  TensorShape shape({3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // fix exponent to avoid overflow
  auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
  RunTest({x}, {shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Prod) {
  TensorShape x_shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Prod(scope_, x, {1});
  // y's shape is the result of reducing x along axis 1
  TensorShape y_shape({2, 1, 2});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

}  // namespace
}  // namespace tensorflow