// math_grad_test.cc revision e4532d20973c4c00854492362665317551661c18
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/gradient_checker.h"
18#include "tensorflow/cc/framework/testutil.h"
19#include "tensorflow/cc/gradients/grad_testutil.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/lib/random/random.h"
24
25namespace tensorflow {
26using namespace ops;  // NOLINT(build/namespaces)
27
28namespace {
29
30// TODO(andydavis) Test gradient function against numeric gradients output.
31// TODO(andydavis) As more gradients are added move common test functions
32// to a testutil library.
33
// Fixture for testing gradients of coefficient-wise (element-wise) unary
// ops. Each test builds y = op(x) for a small fixed-shape input and checks
// the registered symbolic gradient against a numeric estimate.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  // All graphs are pinned to CPU so results do not depend on GPU kernels.
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Selects which op TestCWiseGrad builds. INV maps to the Reciprocal op.
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN,
    REAL,
    IMAG,
    CONJ,
    COMPLEX,
    ANGLE,
    LGAMMA,
    ERF
  };

  // Builds y = op_type(x) over a fixed {2, 3, 2} input, fills x with
  // x_fn(i) for each flat index i, and asserts that the maximum error
  // between the symbolic and numeric gradients stays below 1e-3.
  // X_T is the element type of the input x and Y_T the element type of the
  // output y; they differ for ops such as REAL, IMAG, COMPLEX and ANGLE.
  template <typename X_T, typename Y_T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
    TF_ASSERT_OK(scope_.status());
    DataType x_type = DataTypeToEnum<X_T>::v();
    TensorShape shape({2, 3, 2});
    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
    Tensor x_data(x_type, shape);
    auto x_data_flat = x_data.flat<X_T>();
    for (int i = 0; i < x_data_flat.size(); ++i) {
      x_data_flat(i) = x_fn(i);
    }

    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
      case COMPLEX:
        // Both the real and imaginary parts are fed from the same input x.
        y = Complex(scope_, x, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case LGAMMA:
        y = Lgamma(scope_, x);
        break;
      case ERF:
        y = Erf(scope_, x);
        break;
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
                                                        shape, &max_error)));
    EXPECT_LT(max_error, 1e-3f);
  }

  // Returns one value chosen uniformly at random from v.
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex-valued counterpart of RV.
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // NOTE(review): appears unused within this file — candidate for removal.
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  Scope scope_;
};
201
202TEST_F(CWiseUnaryGradTest, Abs) {
203  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
204  TestCWiseGrad<float, float>(ABS, x_fn);
205}
206
207TEST_F(CWiseUnaryGradTest, Neg) {
208  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
209  TestCWiseGrad<float, float>(NEG, x_fn);
210}
211
212TEST_F(CWiseUnaryGradTest, Reciprocal) {
213  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
214  TestCWiseGrad<float, float>(INV, x_fn);
215}
216
217TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
218  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
219  TestCWiseGrad<complex64, complex64>(INV, x_fn);
220}
221
222TEST_F(CWiseUnaryGradTest, Square) {
223  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
224  TestCWiseGrad<float, float>(SQUARE, x_fn);
225}
226
227TEST_F(CWiseUnaryGradTest, Square_Complex) {
228  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
229  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
230}
231
232TEST_F(CWiseUnaryGradTest, Sqrt) {
233  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
234  TestCWiseGrad<float, float>(SQRT, x_fn);
235}
236
237TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
238  auto x_fn = [this](const int i) {
239    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
240  };
241  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
242}
243
244TEST_F(CWiseUnaryGradTest, Rsqrt) {
245  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
246  TestCWiseGrad<float, float>(RSQRT, x_fn);
247}
248
249TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
250  auto x_fn = [this](const int i) {
251    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
252  };
253  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
254}
255
256TEST_F(CWiseUnaryGradTest, Exp) {
257  auto x_fn = [this](const int i) {
258    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
259  };
260  TestCWiseGrad<float, float>(EXP, x_fn);
261}
262
263TEST_F(CWiseUnaryGradTest, Exp_Complex) {
264  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
265  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
266}
267
268TEST_F(CWiseUnaryGradTest, Expm1) {
269  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
270  TestCWiseGrad<float, float>(EXPM1, x_fn);
271}
272
273TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
274  auto x_fn = [this](const int i) {
275    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
276  };
277  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
278}
279
280TEST_F(CWiseUnaryGradTest, Log) {
281  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
282  TestCWiseGrad<float, float>(LOG, x_fn);
283}
284
285TEST_F(CWiseUnaryGradTest, Log_Complex) {
286  auto x_fn = [this](const int i) {
287    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
288  };
289  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
290}
291
292TEST_F(CWiseUnaryGradTest, Log1p) {
293  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
294  TestCWiseGrad<float, float>(LOG1P, x_fn);
295}
296
297TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
298  auto x_fn = [this](const int i) {
299    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
300  };
301  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
302}
303
304TEST_F(CWiseUnaryGradTest, Sinh) {
305  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
306  TestCWiseGrad<float, float>(SINH, x_fn);
307}
308
309TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
310  auto x_fn = [this](const int i) {
311    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
312  };
313  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
314}
315
316TEST_F(CWiseUnaryGradTest, Cosh) {
317  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
318  TestCWiseGrad<float, float>(COSH, x_fn);
319}
320
321TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
322  auto x_fn = [this](const int i) {
323    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
324  };
325  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
326}
327
328TEST_F(CWiseUnaryGradTest, Tanh) {
329  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
330  TestCWiseGrad<float, float>(TANH, x_fn);
331}
332
333TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
334  auto x_fn = [this](const int i) {
335    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
336  };
337  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
338}
339
340TEST_F(CWiseUnaryGradTest, Asinh) {
341  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
342  TestCWiseGrad<float, float>(ASINH, x_fn);
343}
344
345TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
346  auto x_fn = [this](const int i) {
347    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
348  };
349  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
350}
351
352TEST_F(CWiseUnaryGradTest, Acosh) {
353  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
354  TestCWiseGrad<float, float>(ACOSH, x_fn);
355}
356
357TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
358  auto x_fn = [this](const int i) {
359    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
360  };
361  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
362}
363
364TEST_F(CWiseUnaryGradTest, Atanh) {
365  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
366  TestCWiseGrad<float, float>(ATANH, x_fn);
367}
368
369TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
370  auto x_fn = [this](const int i) {
371    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
372  };
373  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
374}
375
376TEST_F(CWiseUnaryGradTest, Sigmoid) {
377  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
378  TestCWiseGrad<float, float>(SIGMOID, x_fn);
379}
380
381TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
382  auto x_fn = [this](const int i) {
383    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
384  };
385  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
386}
387
388TEST_F(CWiseUnaryGradTest, Sign) {
389  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
390  TestCWiseGrad<float, float>(SIGN, x_fn);
391}
392
393TEST_F(CWiseUnaryGradTest, Sin) {
394  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
395  TestCWiseGrad<float, float>(SIN, x_fn);
396}
397
398TEST_F(CWiseUnaryGradTest, Sin_Complex) {
399  auto x_fn = [this](const int i) {
400    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
401  };
402  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
403}
404
405TEST_F(CWiseUnaryGradTest, Cos) {
406  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
407  TestCWiseGrad<float, float>(COS, x_fn);
408}
409
410TEST_F(CWiseUnaryGradTest, Cos_Complex) {
411  auto x_fn = [this](const int i) {
412    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
413  };
414  TestCWiseGrad<complex64, complex64>(COS, x_fn);
415}
416
417TEST_F(CWiseUnaryGradTest, Asin) {
418  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
419  TestCWiseGrad<float, float>(ASIN, x_fn);
420}
421
422TEST_F(CWiseUnaryGradTest, Asin_Complex) {
423  auto x_fn = [this](const int i) {
424    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
425  };
426  // TODO(kbsriram)
427  // Enable test when the asin kernel supports complex numbers
428  if (false) {
429    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
430  }
431}
432
433TEST_F(CWiseUnaryGradTest, Acos) {
434  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
435  TestCWiseGrad<float, float>(ACOS, x_fn);
436}
437
438TEST_F(CWiseUnaryGradTest, Acos_Complex) {
439  auto x_fn = [this](const int i) {
440    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
441  };
442  // TODO(kbsriram)
443  // Add test when the acos kernel supports complex numbers
444  if (false) {
445    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
446  }
447}
448
449TEST_F(CWiseUnaryGradTest, Tan) {
450  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
451  TestCWiseGrad<float, float>(TAN, x_fn);
452}
453
454TEST_F(CWiseUnaryGradTest, Tan_Complex) {
455  auto x_fn = [this](const int i) {
456    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
457  };
458  // TODO(kbsriram)
459  // Enable when tan kernel supports complex inputs
460  if (false) {
461    TestCWiseGrad<complex64, complex64>(TAN, x_fn);
462  }
463}
464
465TEST_F(CWiseUnaryGradTest, Atan) {
466  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
467  TestCWiseGrad<float, float>(ATAN, x_fn);
468}
469
470TEST_F(CWiseUnaryGradTest, Atan_Complex) {
471  auto x_fn = [this](const int i) {
472    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
473  };
474  // TODO(kbsriram)
475  // Add test when the atan kernel supports complex numbers
476  if (false) {
477    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
478  }
479}
480
481TEST_F(CWiseUnaryGradTest, Real) {
482  auto x_fn = [this](const int i) {
483    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
484  };
485  TestCWiseGrad<complex64, float>(REAL, x_fn);
486}
487
488TEST_F(CWiseUnaryGradTest, Imag) {
489  auto x_fn = [this](const int i) {
490    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
491  };
492  TestCWiseGrad<complex64, float>(IMAG, x_fn);
493}
494
495TEST_F(CWiseUnaryGradTest, Conj) {
496  auto x_fn = [this](const int i) {
497    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
498  };
499  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
500}
501
502TEST_F(CWiseUnaryGradTest, Complex) {
503  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
504  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
505}
506
507TEST_F(CWiseUnaryGradTest, Angle) {
508  auto x_fn = [this](const int i) {
509    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
510  };
511  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
512}
513
514TEST_F(CWiseUnaryGradTest, Lgamma) {
515  auto x_fn = [this](const int i) {
516    return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
517  };
518  TestCWiseGrad<float, float>(LGAMMA, x_fn);
519}
520
521TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
522  auto x_fn = [this](const int i) {
523    return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
524  };
525  // TODO(kbsriram)
526  // Add test when the lgamma kernel supports complex numbers
527  if (false) {
528    TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
529  }
530}
531
532TEST_F(CWiseUnaryGradTest, Erf) {
533  auto x_fn = [this](const int i) {
534    return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
535  };
536  TestCWiseGrad<float, float>(ERF, x_fn);
537}
538
539TEST_F(CWiseUnaryGradTest, Erf_Complex) {
540  auto x_fn = [this](const int i) {
541    return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
542  };
543  // TODO(kbsriram)
544  // Add test when the erf kernel supports complex numbers
545  if (false) {
546    TestCWiseGrad<complex64, complex64>(ERF, x_fn);
547  }
548}
549
// Fixture for MatMul / BatchMatMul gradient tests over randomly generated
// (but mutually compatible) operand shapes.
class MathGradTest : public ::testing::Test {
 protected:
  // Pin to CPU so the tested kernels do not depend on GPU availability.
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Builds z = x * y as MatMul (or BatchMatMul when is_batch is true) with
  // the given transpose flags, then numerically checks dz/dx and dz/dy.
  // T is the element type (float or complex64 in these tests).
  // NOTE(review): for BatchMatMul the t_x/t_y flags map to AdjX/AdjY
  // (adjoint), which for complex types also conjugates — confirm that is
  // the intended semantics for the complex batch variants.
  template <typename T>
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    TF_ASSERT_OK(root_.status());
    // Generate random (but compatible) shapes for matrix multiplication.
    std::vector<TensorShape> shapes;
    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
    TensorShape x_shape = shapes[0];
    TensorShape y_shape = shapes[1];
    TensorShape z_shape = shapes[2];
    auto x =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
    auto y =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  // Appends the shapes of x, y and z = x * y (in that order) to *shapes,
  // with random dimension sizes, already swapped/batched as requested by
  // tx/ty/is_batch so the multiplication is valid after (un)transposing.
  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
                        std::vector<TensorShape>* shapes) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    shapes->push_back(x_shape);

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    shapes->push_back(y_shape);

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    shapes->push_back(z_shape);
  }

  // Random dimension size in [1, 10].
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};
624
// (Batch)MatMul gradients for every combination of element type
// (float, complex64) and transpose flags.

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
688
// Fixture for gradient tests of n-ary math ops (reductions, broadcasting
// binary ops, AddN, Select).
class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Numerically checks the gradient of each output in ys with respect to
  // each input in xs; asserts the maximum error stays below 1e-3.
  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  // Single-input variant taking an explicit initial value for x; used when
  // the test depends on particular input values (e.g. a unique min/max).
  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, x, x_init_value, y, y_shape, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};
713
// Gradient of reduce_sum over two axes (one given as a negative index).
TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

// Same reduction pattern as Sum, for reduce_mean.
TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
731
TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axes -1 (= 1)
  TensorShape y_shape({2});
  // Each row has a unique minimum, so the gradient routing is unambiguous
  // and the numeric estimate is stable (see MinMulti for the tied case).
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axes -1 (= 1)
  TensorShape y_shape({2});
  // Each row has a unique maximum; same rationale as the Min test above.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}
753
TEST_F(NaryGradTest, MinMulti) {
  // Test gradient when there are multiple minima.
  // Note that we cannot directly use a test Tensor with multiple
  // minima, as the numeric estimator will calculate incorrect
  // gradients when perturbing each entry in the Tensor (which then
  // changes how many minima exist.)
  // Instead, we use a single input that broadcast-multiplies a larger
  // tensor with equal values, and apply reduce_min to the multiplied
  // result.
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

// Same broadcast-multiply construction as MinMulti, applied to reduce_max:
// every element of the reduced tensor ties for the maximum.
TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
780
// AddN sums three same-shaped inputs; the gradient with respect to each
// input is checked.
TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

// x2's [2, 5] shape broadcasts against x1's [3, 2, 5], so this also
// exercises the gradient's reduction over the broadcast dimension.
TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
799
// Like the Add test: mismatched shapes so broadcasting is exercised.
TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

// Like the Add test: mismatched shapes so broadcasting is exercised.
TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

// Same construction as the Div test, for the RealDiv op.
TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

// (x1 - x2)^2 with mismatched shapes so broadcasting is exercised.
TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
845
TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

// Elementwise select built from Greater + Where3; this computes the
// elementwise maximum of x1 and x2 and checks gradients through Select.
TEST_F(NaryGradTest, Select) {
  TensorShape shape({3, 4});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Where3(scope_, Greater(scope_, x1, x2), x1, x2);
  RunTest({x1, x2}, {shape, shape}, {y}, {shape});
}
875
876}  // namespace
877}  // namespace tensorflow
878