math_grad.cc revision 8624ecc9e827a40f9b514ff6b8ed925390ca79cc
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}
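
// Illustrative example: the gradient functions below follow the pattern
// grad(x) = grad(y) * conj(dy/dx) stated in their comments. For real
// dtypes ConjugateHelper is the identity, so e.g. for y = x^2 with real x
// this reduces to grad(x) = grad(y) * 2 * x, which is what SquareGrad
// computes.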

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}
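
// Illustrative example: if x_1 has shape [2, 3] and x_2 has shape [3],
// BroadcastGradientArgs returns the reduction axes each input was
// broadcast over (none for x_1, axis 0 for x_2), so gx_2 is summed over
// axis 0 and reshaped back to [3], while gx_1 keeps shape [2, 3].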

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
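
// Illustrative note: at points where x_1 == x_2, the GreaterEqual
// (Maximum) and LessEqual (Minimum) comparators used below are both true,
// so the entire incoming gradient is routed to x_1 and x_2 receives zero.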

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGradHelper computes the two MatMul products that make up the
// gradient, based on the input matrix transposition combination.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's attr state, then determines
// the proper MatMul products for the gradients based on the input matrix
// transposition combination.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}
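
// Worked example of the four cases above, with a = op.input(0),
// b = op.input(1), and dy = grad_inputs[0]:
//   y = a * b      ->  da = dy * b^T,    db = a^T * dy
//   y = a * b^T    ->  da = dy * b,      db = dy^T * a
//   y = a^T * b    ->  da = b * dy^T,    db = a * dy
//   y = a^T * b^T  ->  da = b^T * dy^T,  db = dy^T * a^T
// These are exactly what the MatMulGradHelper calls encode via their
// transpose/adjoint flags.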

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow