math_grad_test.cc revision e6b011763a60d239972c8c6c0f36536ab6f885a3
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/framework/gradient_checker.h"
18#include "tensorflow/cc/framework/testutil.h"
19#include "tensorflow/cc/gradients/grad_testutil.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/lib/random/random.h"
24
25namespace tensorflow {
26using namespace ops;  // NOLINT(build/namespaces)
27
28namespace {
29
30// TODO(andydavis) Test gradient function against numeric gradients output.
31// TODO(andydavis) As more gradients are added move common test functions
32// to a testutil library.
33
// Fixture for element-wise unary op gradient tests. Each test supplies
// closures that generate the input values, the upstream gradient, and the
// analytically expected input gradient; TestCWiseGrad() checks the registered
// gradient function against them.
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  // Pin all ops to CPU so the tests do not depend on GPU kernel availability.
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Selects which unary op TestCWiseGrad() builds (see the switch below).
  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN
  };

  // Runs one gradient check for the op selected by 'op_type':
  //   x_fn(i)      -> value of the i-th input element.
  //   dy_fn(x)     -> upstream gradient fed in for input element x.
  //   dx_fn(x, dy) -> analytically expected input gradient.
  // Builds y = op(x) on a fixed {2, 3, 2} tensor, calls the registered
  // gradient function, and compares its output against the expected dx.
  template <typename T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
                     const std::function<T(const T&)>& dy_fn,
                     const std::function<T(const T&, const T&)>& dx_fn) {
    DataType dtype = DataTypeToEnum<T>::v();
    // Populate the input tensor element-wise from x_fn.
    Tensor x(dtype, {2, 3, 2});
    auto x_flat = x.flat<T>();
    for (int i = 0; i < x_flat.size(); ++i) {
      x_flat(i) = x_fn(i);
    }

    // Upstream gradient, derived element-wise from the input.
    Tensor dy(dtype, {2, 3, 2});
    auto dy_flat = dy.flat<T>();
    for (int i = 0; i < dy_flat.size(); ++i) {
      dy_flat(i) = dy_fn(x_flat(i));
    }

    // Expected input gradient, computed element-wise from (x, dy).
    Tensor dx(dtype, {2, 3, 2});
    auto dx_flat = dx.flat<T>();
    for (int i = 0; i < dx_flat.size(); ++i) {
      dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
    }

    // Build the forward op selected by op_type.
    Output y;
    switch (op_type) {
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
    }

    // Invoke the registered gradient function with dy and compare its result
    // against the expected gradient tensor.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(test::CallGradFunction(
        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
    Tensor output;
    test::GetTensor(scope_, grad_outputs[0], &output);
    test::ExpectClose(output, dx);
  }

  // Returns an element of 'v' chosen uniformly at random.
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex-valued counterpart of RV().
  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  // Complex conjugate of 'val'.
  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  // Complex constant 1 + 0i, used in the expected-gradient formulas.
  const complex64 one_{1.0, 0};

  Scope scope_;
};
188
189TEST_F(CWiseUnaryGradTest, Abs) {
190  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
191  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
192  auto dx_fn = [this](const float x, const float dy) { return x * dy; };
193  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
194}
195
196TEST_F(CWiseUnaryGradTest, Neg) {
197  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
198  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
199  auto dx_fn = [this](const float x, const float dy) { return -dy; };
200  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
201}
202
203TEST_F(CWiseUnaryGradTest, Reciprocal) {
204  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
205  auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
206  auto dx_fn = [this](const float x, const float dy) {
207    return -(1 / (x * x)) * dy;
208  };
209  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
210}
211
212TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
213  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
214  auto dy_fn = [this](const complex64 x) {
215    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
216  };
217  auto dx_fn = [this](const complex64 x, const complex64 dy) {
218    return -conjugate(one_ / (x * x)) * dy;
219  };
220  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
221}
222
223TEST_F(CWiseUnaryGradTest, Square) {
224  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
225  auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
226  auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
227  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
228}
229
230TEST_F(CWiseUnaryGradTest, Square_Complex) {
231  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
232  auto dy_fn = [this](const complex64& x) {
233    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
234  };
235  auto dx_fn = [this](const complex64& x, const complex64& dy) {
236    return conjugate(complex64(2, 0) * x) * dy;
237  };
238  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
239}
240
241TEST_F(CWiseUnaryGradTest, Sqrt) {
242  auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
243  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
244  auto dx_fn = [this](const float x, const float dy) {
245    return dy * 0.5 * (1.0 / std::sqrt(x));
246  };
247  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
248}
249
250TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
251  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
252  auto dy_fn = [this](const complex64& x) {
253    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
254  };
255  auto dx_fn = [this](const complex64& x, const complex64& dy) {
256    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
257  };
258  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
259}
260
261TEST_F(CWiseUnaryGradTest, Rsqrt) {
262  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
263  auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
264  auto dx_fn = [this](const float x, const float dy) {
265    return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
266  };
267  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
268}
269
270TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
271  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
272  auto dy_fn = [this](const complex64& x) {
273    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
274  };
275  auto dx_fn = [this](const complex64& x, const complex64& dy) {
276    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
277  };
278  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
279}
280
281TEST_F(CWiseUnaryGradTest, Exp) {
282  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
283  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
284  auto dx_fn = [this](const float x, const float dy) {
285    return dy * std::exp(x);
286  };
287  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
288}
289
290TEST_F(CWiseUnaryGradTest, Exp_Complex) {
291  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
292  auto dy_fn = [this](const complex64& x) {
293    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
294  };
295  auto dx_fn = [this](const complex64& x, const complex64& dy) {
296    return dy * conjugate(std::exp(x));
297  };
298  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
299}
300
301TEST_F(CWiseUnaryGradTest, Expm1) {
302  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); };
303  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
304  auto dx_fn = [this](const float x, const float dy) {
305    return dy * std::exp(x);
306  };
307  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
308}
309
310TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
311  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
312  auto dy_fn = [this](const complex64& x) {
313    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
314  };
315  auto dx_fn = [this](const complex64& x, const complex64& dy) {
316    return dy * conjugate(std::exp(x));
317  };
318  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
319}
320
321TEST_F(CWiseUnaryGradTest, Log) {
322  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
323  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
324  auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
325  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
326}
327
328TEST_F(CWiseUnaryGradTest, Log_Complex) {
329  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
330  auto dy_fn = [this](const complex64& x) {
331    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
332  };
333  auto dx_fn = [this](const complex64& x, const complex64& dy) {
334    return dy * conjugate(one_ / x);
335  };
336  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
337}
338
339TEST_F(CWiseUnaryGradTest, Log1p) {
340  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
341  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
342  auto dx_fn = [this](const float x, const float dy) {
343    return dy * (1.0 / (1.0 + x));
344  };
345  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
346}
347
348TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
349  auto x_fn = [this](const int i) {
350    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
351  };
352  auto dy_fn = [this](const complex64& x) {
353    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
354  };
355  auto dx_fn = [this](const complex64& x, const complex64& dy) {
356    return dy / (one_ + conjugate(x));
357  };
358  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
359}
360
361TEST_F(CWiseUnaryGradTest, Sinh) {
362  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
363  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
364  auto dx_fn = [this](const float x, const float dy) {
365    return dy * std::cosh(x);
366  };
367  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
368}
369
370TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
371  auto x_fn = [this](const int i) {
372    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
373  };
374  auto dy_fn = [this](const complex64& x) {
375    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
376  };
377  auto dx_fn = [this](const complex64& x, const complex64& dy) {
378    return dy * conjugate(std::cosh(x));
379  };
380  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
381}
382
383TEST_F(CWiseUnaryGradTest, Cosh) {
384  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
385  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
386  auto dx_fn = [this](const float x, const float dy) {
387    return dy * std::sinh(x);
388  };
389  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
390}
391
392TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
393  auto x_fn = [this](const int i) {
394    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
395  };
396  auto dy_fn = [this](const complex64& x) {
397    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
398  };
399  auto dx_fn = [this](const complex64& x, const complex64& dy) {
400    return dy * conjugate(std::sinh(x));
401  };
402  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
403}
404
405TEST_F(CWiseUnaryGradTest, Tanh) {
406  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
407  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
408  auto dx_fn = [this](const float x, const float dy) {
409    const float y = std::tanh(x);
410    return dy * (1.0 - y * y);
411  };
412  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
413}
414
415TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
416  auto x_fn = [this](const int i) {
417    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
418  };
419  auto dy_fn = [this](const complex64& x) {
420    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
421  };
422  auto dx_fn = [this](const complex64& x, const complex64& dy) {
423    const complex64 y = std::tanh(x);
424    return dy * conjugate((one_ - y * y));
425  };
426  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
427}
428
429TEST_F(CWiseUnaryGradTest, Asinh) {
430  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
431  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
432  auto dx_fn = [this](const float x, const float dy) {
433    auto y = std::asinh(x);
434    return dy / std::cosh(y);
435  };
436  TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn);
437}
438
439TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
440  auto x_fn = [this](const int i) {
441    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
442  };
443  auto dy_fn = [this](const complex64& x) {
444    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
445  };
446  auto dx_fn = [this](const complex64& x, const complex64& dy) {
447    auto y = std::asinh(x);
448    return dy / conjugate(std::cosh(y));
449  };
450  TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn);
451}
452
453TEST_F(CWiseUnaryGradTest, Acosh) {
454  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); };
455  auto dy_fn = [this](const float x) {
456    return x + RV({8, 9, 10, 11, 12, 13, 14});
457  };
458  auto dx_fn = [this](const float x, const float dy) {
459    auto y = std::acosh(x);
460    return dy / std::sinh(y);
461  };
462  TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn);
463}
464
465TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
466  auto x_fn = [this](const int i) {
467    return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}});
468  };
469  auto dy_fn = [this](const complex64& x) {
470    return x + CRV({{2, 2}, {3, 3}, {1, 4}});
471  };
472  auto dx_fn = [this](const complex64& x, const complex64& dy) {
473    auto y = std::acosh(x);
474    return dy / conjugate(std::sinh(y));
475  };
476  TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn);
477}
478
479TEST_F(CWiseUnaryGradTest, Atanh) {
480  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
481  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
482  auto dx_fn = [this](const float x, const float dy) {
483    return dy * (1. / (1. - x * x));
484  };
485  TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn);
486}
487
488TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
489  auto x_fn = [this](const int i) {
490    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
491  };
492  auto dy_fn = [this](const complex64& x) {
493    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
494  };
495  auto dx_fn = [this](const complex64& x, const complex64& dy) {
496    return dy / conjugate(one_ - x * x);
497  };
498  TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn);
499}
500
501TEST_F(CWiseUnaryGradTest, Sigmoid) {
502  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
503  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
504  auto dx_fn = [this](const float x, const float dy) {
505    const float y = 1.0 / (1.0 + std::exp(-x));
506    return dy * y * (1.0 - y);
507  };
508  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
509}
510
511TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
512  auto x_fn = [this](const int i) {
513    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
514  };
515  auto dy_fn = [this](const complex64& x) {
516    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
517  };
518  auto dx_fn = [this](const complex64& x, const complex64& dy) {
519    const complex64 y = one_ / (one_ + std::exp(-x));
520    return dy * conjugate(y * (one_ - y));
521  };
522  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
523}
524
525TEST_F(CWiseUnaryGradTest, Sign) {
526  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
527  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
528  auto dx_fn = [this](const float x, const float dy) { return 0.0; };
529  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
530}
531
532TEST_F(CWiseUnaryGradTest, Sin) {
533  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
534  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
535  auto dx_fn = [this](const float x, const float dy) {
536    return dy * std::cos(x);
537  };
538  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
539}
540
541TEST_F(CWiseUnaryGradTest, Sin_Complex) {
542  auto x_fn = [this](const int i) {
543    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
544  };
545  auto dy_fn = [this](const complex64& x) {
546    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
547  };
548  auto dx_fn = [this](const complex64& x, const complex64& dy) {
549    return dy * conjugate(std::cos(x));
550  };
551  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
552}
553
554TEST_F(CWiseUnaryGradTest, Cos) {
555  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
556  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
557  auto dx_fn = [this](const float x, const float dy) {
558    return dy * -1.0 * std::sin(x);
559  };
560  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
561}
562
563TEST_F(CWiseUnaryGradTest, Cos_Complex) {
564  auto x_fn = [this](const int i) {
565    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
566  };
567  auto dy_fn = [this](const complex64& x) {
568    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
569  };
570  auto dx_fn = [this](const complex64& x, const complex64& dy) {
571    return dy * conjugate(-std::sin(x));
572  };
573  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
574}
575
576TEST_F(CWiseUnaryGradTest, Asin) {
577  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
578  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
579  auto dx_fn = [this](const float x, const float dy) {
580    return dy * (1.0 / std::sqrt(1.0 - x * x));
581  };
582  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
583}
584
585TEST_F(CWiseUnaryGradTest, Asin_Complex) {
586  auto x_fn = [this](const int i) {
587    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
588  };
589  auto dy_fn = [this](const complex64& x) {
590    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
591  };
592  auto dx_fn = [this](const complex64& x, const complex64& dy) {
593    return dy / conjugate(std::sqrt(one_ - x * x));
594  };
595  // TODO(kbsriram)
596  // Enable test when the asin kernel supports complex numbers
597  if (false) {
598    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
599  }
600}
601
602TEST_F(CWiseUnaryGradTest, Acos) {
603  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
604  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
605  auto dx_fn = [this](const float x, const float dy) {
606    return dy * (-1.0 / std::sqrt(1.0 - x * x));
607  };
608  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
609}
610
611TEST_F(CWiseUnaryGradTest, Acos_Complex) {
612  auto x_fn = [this](const int i) {
613    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
614  };
615  auto dy_fn = [this](const complex64& x) {
616    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
617  };
618  auto dx_fn = [this](const complex64& x, const complex64& dy) {
619    return dy / -conjugate(std::sqrt(one_ - x * x));
620  };
621  // TODO(kbsriram)
622  // Add test when the acos kernel supports complex numbers
623  if (false) {
624    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
625  }
626}
627
628TEST_F(CWiseUnaryGradTest, Tan) {
629  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
630  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
631  auto dx_fn = [this](const float x, const float dy) {
632    const float cosx = std::cos(x);
633    return dy * (1 / (cosx * cosx));
634  };
635  TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
636}
637
638TEST_F(CWiseUnaryGradTest, Tan_Complex) {
639  auto x_fn = [this](const int i) {
640    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
641  };
642  auto dy_fn = [this](const complex64& x) {
643    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
644  };
645  auto dx_fn = [this](const complex64& x, const complex64& dy) {
646    const complex64 cosx = std::cos(x);
647    return dy / conjugate(cosx * cosx);
648  };
649  // TODO(kbsriram)
650  // Enable when tan kernel supports complex inputs
651  if (false) {
652    TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
653  }
654}
655
656TEST_F(CWiseUnaryGradTest, Atan) {
657  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
658  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
659  auto dx_fn = [this](const float x, const float dy) {
660    return dy * (1 / (1 + x * x));
661  };
662  TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
663}
664
665TEST_F(CWiseUnaryGradTest, Atan_Complex) {
666  auto x_fn = [this](const int i) {
667    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
668  };
669  auto dy_fn = [this](const complex64& x) {
670    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
671  };
672  auto dx_fn = [this](const complex64& x, const complex64& dy) {
673    return dy / (one_ + x * x);
674  };
675  // TODO(kbsriram)
676  // Add test when the atan kernel supports complex numbers
677  if (false) {
678    TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
679  }
680}
681
682class CWiseUnaryComplexGradTest : public ::testing::Test {
683 protected:
684  CWiseUnaryComplexGradTest()
685      : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
686
687  enum UnaryOpType { REAL, IMAG, ANGLE, CONJ };
688
689  void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
690                            const Tensor& dy, const Tensor& dx_expected) {
691    Output y;
692    switch (op_type) {
693      case REAL:
694        y = Real(scope_, x);
695        break;
696      case IMAG:
697        y = Imag(scope_, x);
698        break;
699      case ANGLE:
700        y = Angle(scope_, x);
701        break;
702      case CONJ:
703        y = Conj(scope_, x);
704        break;
705    }
706
707    std::vector<Output> grad_outputs;
708    TF_ASSERT_OK(test::CallGradFunction(
709        scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
710    Tensor dx;
711    test::GetTensor(scope_, grad_outputs[0], &dx);
712    test::ExpectClose(dx, dx_expected);
713  }
714
715  Scope scope_;
716};
717
718TEST_F(CWiseUnaryComplexGradTest, Real) {
719  Tensor x = test::AsTensor<complex64>(
720      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
721  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
722  Tensor dx_expected = test::AsTensor<complex64>(
723      {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
724  TestCWiseGradComplex(REAL, x, dy, dx_expected);
725}
726
727TEST_F(CWiseUnaryComplexGradTest, Imag) {
728  Tensor x = test::AsTensor<complex64>(
729      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
730  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
731  Tensor dx_expected = test::AsTensor<complex64>(
732      {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
733  TestCWiseGradComplex(IMAG, x, dy, dx_expected);
734}
735
736TEST_F(CWiseUnaryComplexGradTest, Angle) {
737  Tensor x = test::AsTensor<complex64>(
738      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
739  Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
740  Tensor dx_expected =
741      test::AsTensor<complex64>({{5.5, 5.5},
742                                 {3, 3},
743                                 {2.1666666666666665, 2.1666666666666665},
744                                 {1.75, 1.75},
745                                 {0.9375, 0.9375},
746                                 {0.8888888888888888, 0.8888888888888888}},
747                                {2, 3});
748  TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
749}
750
751TEST_F(CWiseUnaryComplexGradTest, Conj) {
752  Tensor x = test::AsTensor<complex64>(
753      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
754  Tensor dy = test::AsTensor<complex64>(
755      {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
756  Tensor dx_expected = test::AsTensor<complex64>(
757      {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
758  TestCWiseGradComplex(CONJ, x, dy, dx_expected);
759}
760
// Fixture for MatMul / BatchMatMul gradient tests. For each combination of
// transpose flags, the gradient function's outputs are compared against the
// matrix products that are analytically known to equal dx and dy.
class MathGradTest : public ::testing::Test {
 protected:
  // Pin all ops to CPU so the tests do not depend on GPU kernel availability.
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  // Checks the gradient of z = MatMul(x, y) (or BatchMatMul when 'is_batch')
  // with transpose/adjoint flags 't_x' and 't_y' against the closed-form
  // expressions for dx and dy.
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    // Generate random test data.
    std::vector<Tensor> data;
    RandMatMulGradData(is_batch, t_x, t_y, &data);
    auto x = Const(root_, data[0]);
    auto y = Const(root_, data[1]);
    auto dz = Const(root_, data[2]);

    std::vector<Tensor> grad_outputs;
    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);

    // Each branch compares {dx, dy} against the matrix products implied by
    // the chain rule for the corresponding transpose configuration.
    if (!t_x && !t_y) {
      // z = x * y        => dx = dz * y^T,  dy = x^T * dz.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, true, dz, false));
    } else if (t_x && !t_y) {
      // z = x^T * y      => dx = y * dz^T,  dy = x * dz.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, false, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, x, false, dz, false));
    } else if (!t_x && t_y) {
      // z = x * y^T      => dx = dz * y,    dy = dz^T * x.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, dz, false, y, false));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, false));
    } else {
      // z = x^T * y^T    => dx = y^T * dz^T, dy = dz^T * x^T.
      test::ExpectClose(grad_outputs[0],
                        ComputeMatMul(is_batch, y, true, dz, true));
      test::ExpectClose(grad_outputs[1],
                        ComputeMatMul(is_batch, dz, true, x, true));
    }
  }

  // Builds z = MatMul(x, y) (or BatchMatMul), runs the registered gradient
  // function with upstream gradient 'dz', and returns the evaluated {dx, dy}
  // tensors in 'out'.
  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
                         const Output& y, const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_ASSERT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
                                        &grad_outputs));
    ASSERT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  // Evaluates MatMul(x, y) (or BatchMatMul) with the given transpose flags
  // and returns the resulting tensor. Used to build expected gradients.
  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
                       const Output& y, const bool t_y) {
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }
    TF_EXPECT_OK(root_.status());
    Tensor out;
    test::GetTensor(root_, z, &out);
    return out;
  }

  // Fills 'data' with random tensors {x, y, dz} whose shapes are consistent
  // with z = MatMul(x, y) under the given batch/transpose configuration.
  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // Choose a random batch size in [1, 4]
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    data->emplace_back(DT_FLOAT, z_shape);
    RandTensor(&data->back());
  }

  // Fills 't' with random values produced by Rand().
  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  // Random dimension size in [1, 10].
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};
886
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}
890
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
894
// 2-D MatMul gradient with the second operand transposed.
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad(false, false, true);
}
898
// 2-D MatMul gradient with both operands transposed.
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(false, true, true);
}
902
// Batched MatMul gradient, neither input transposed.
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad(true, false, false);
}
906
// Batched MatMul gradient with the first operand transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad(true, true, false);
}
910
// Batched MatMul gradient with the second operand transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad(true, false, true);
}
914
// Batched MatMul gradient with both operands transposed.
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad(true, true, true);
}
918
919class NaryGradTest : public ::testing::Test {
920 protected:
921  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
922
923  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
924               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
925    TF_ASSERT_OK(scope_.status());
926    float max_error;
927    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
928        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
929    EXPECT_LT(max_error, 1e-3);
930  }
931
932  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
933               const TensorShape& y_shape) {
934    TF_ASSERT_OK(scope_.status());
935    float max_error;
936    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
937        scope_, x, x_init_value, y, y_shape, &max_error)));
938    EXPECT_LT(max_error, 1e-3);
939  }
940
941  Scope scope_;
942};
943
944TEST_F(NaryGradTest, Sum) {
945  TensorShape x_shape({2, 3, 5, 7});
946  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
947  auto y = Sum(scope_, x, {1, -1});
948  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
949  TensorShape y_shape({2, 5});
950  RunTest({x}, {x_shape}, {y}, {y_shape});
951}
952
953TEST_F(NaryGradTest, Mean) {
954  TensorShape x_shape({2, 3, 5, 7});
955  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
956  auto y = Mean(scope_, x, {1, -1});
957  // y's shape is the result of reducing x along axes 1 and -1 (= 3)
958  TensorShape y_shape({2, 5});
959  RunTest({x}, {x_shape}, {y}, {y_shape});
960}
961
962TEST_F(NaryGradTest, Min) {
963  TensorShape x_shape({2, 3});
964  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
965  auto y = Min(scope_, x, {-1});
966  // y's shape is the result of reducing x along axes -1 (= 1)
967  TensorShape y_shape({2});
968  Tensor x_init_value =
969      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
970  RunTest(x, x_init_value, y, y_shape);
971}
972
973TEST_F(NaryGradTest, Max) {
974  TensorShape x_shape({2, 3});
975  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
976  auto y = Max(scope_, x, {-1});
977  // y's shape is the result of reducing x along axes -1 (= 1)
978  TensorShape y_shape({2});
979  Tensor x_init_value =
980      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
981  RunTest(x, x_init_value, y, y_shape);
982}
983
984TEST_F(NaryGradTest, MinMulti) {
985  // Test gradient when there are multiple minima.
986  // Note that we cannot directly use a test Tensor with multiple
987  // minima, as the numeric estimator will calculate incorrect
988  // gradients when perturbing each entry in the Tensor (which then
989  // changes how many minima exist.)
990  // Instead, we use a single input that broadcast-multiplies a larger
991  // tensor with equal values, and apply reduce_min to the multiplied
992  // result.
993  TensorShape x_shape({1});
994  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
995  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
996  auto y = Min(scope_, all_same, {0});
997  // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
998  TensorShape y_shape({1});
999  RunTest({x}, {x_shape}, {y}, {y_shape});
1000}
1001
1002TEST_F(NaryGradTest, MaxMulti) {
1003  TensorShape x_shape({1});
1004  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1005  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
1006  auto y = Max(scope_, all_same, {0});
1007  TensorShape y_shape({1});
1008  RunTest({x}, {x_shape}, {y}, {y_shape});
1009}
1010
1011TEST_F(NaryGradTest, AddN) {
1012  TensorShape shape({3, 2, 5});
1013  std::vector<Output> xs;
1014  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
1015  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
1016  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
1017  auto y = AddN(scope_, xs);
1018  RunTest(xs, {shape, shape, shape}, {y}, {shape});
1019}
1020
1021TEST_F(NaryGradTest, Add) {
1022  TensorShape x1_shape({3, 2, 5});
1023  TensorShape x2_shape({2, 5});
1024  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
1025  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
1026  auto y = Add(scope_, x1, x2);
1027  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
1028}
1029
1030TEST_F(NaryGradTest, Sub) {
1031  TensorShape x1_shape({3, 2, 5});
1032  TensorShape x2_shape({2, 5});
1033  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
1034  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
1035  auto y = Sub(scope_, x1, x2);
1036  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
1037}
1038
1039TEST_F(NaryGradTest, Mul) {
1040  TensorShape x1_shape({3, 2, 5});
1041  TensorShape x2_shape({2, 5});
1042  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
1043  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
1044  auto y = Mul(scope_, x1, x2);
1045  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
1046}
1047
1048TEST_F(NaryGradTest, Div) {
1049  TensorShape x_shape({3, 2, 5});
1050  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1051  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
1052  // division errors in the numeric estimator used by the gradient checker.
1053  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
1054  RunTest({x}, {x_shape}, {y}, {x_shape});
1055}
1056
1057TEST_F(NaryGradTest, RealDiv) {
1058  TensorShape x_shape({3, 2, 5});
1059  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1060  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
1061  // division errors in the numeric estimator used by the gradient checker.
1062  auto y =
1063      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
1064  RunTest({x}, {x_shape}, {y}, {x_shape});
1065}
1066
1067TEST_F(NaryGradTest, SquaredDifference) {
1068  TensorShape x1_shape({3, 2, 5});
1069  TensorShape x2_shape({2, 5});
1070  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
1071  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
1072  auto y = SquaredDifference(scope_, x1, x2);
1073  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
1074}
1075
1076TEST_F(NaryGradTest, Maximum) {
1077  TensorShape shape({3, 2});
1078  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1079  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
1080  // Select values away from 1.0f to avoid instability when computing
1081  // finite differences.
1082  Tensor x_init_value =
1083      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
1084  RunTest(x, x_init_value, y, shape);
1085}
1086
1087TEST_F(NaryGradTest, Minimum) {
1088  TensorShape shape({3, 2});
1089  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1090  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
1091  // Select values away from 1.0f to avoid instability when computing
1092  // finite differences.
1093  Tensor x_init_value =
1094      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
1095  RunTest(x, x_init_value, y, shape);
1096}
1097
1098}  // namespace
1099}  // namespace tensorflow
1100