math_grad.cc revision 6e3e7d18f42cb4237ce6dbe2ffd0f9f158c36daf
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Comparison and logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// ConjugateHelper returns the complex conjugate of an Output if it is
// complex valued, and the Output itself otherwise.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}
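
// Note on the pattern used below: most of the elementwise gradient functions
// compute grad(x) = grad(y) * conj(dy/dx), as stated in their comments.
// ConjugateHelper is a no-op for real dtypes, so for real tensors this is
// simply the chain rule grad(x) = grad(y) * dy/dx.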

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);
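
// A minimal usage sketch (illustrative only, not part of this file): with the
// gradient registered above, a client that also includes
// "tensorflow/cc/framework/gradients.h" could build dx for y = Abs(x) roughly
// as follows:
//
//   Scope root = Scope::NewRootScope();
//   auto x = Placeholder(root, DT_FLOAT);
//   auto y = Abs(root, x);
//   std::vector<Output> dx;
//   TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x}, {OnesLike(root, y)},
//                                    &dx));
//
// AddSymbolicGradients looks up AbsGrad through GradOpRegistry (populated by
// REGISTER_GRADIENT_OP) to construct the gradient subgraph.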

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = -1/x^2 = -y^2
  auto dydx = Neg(scope, Square(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = 2 * x
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dydx = Mul(scope, half, y_inv);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1 / x^(1/2) = x^(-1/2)
  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * x^(-1)
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto dydx = Mul(scope, a, y);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Sub(scope, one, y2);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
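  // (from x = sinh(y): dx/dy = cosh(y), so dy/dx = 1 / cosh(y))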
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
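  // (dy/dx = exp(-x) / (1 + exp(-x))^2, which simplifies to y * (1 - y))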
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Mul(scope, y, Sub(scope, one, y));
  // dx = dy * y * (1 - y)
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
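  // Sign(x) is piecewise constant, so dy/dx = 0 almost everywhere; the
  // gradient is a zero tensor with the same shape as x.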
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
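// The incoming gradients gx_1 and gx_2 have the broadcast (output) shape;
// BroadcastGradientArgs returns, for each input, the axes along which that
// input was broadcast, so summing over those axes and reshaping recovers
// gradients with the original input shapes. For example, for input shapes
// [2, 3] and [3], r0 is empty and r1 is {0}, so gx_2 is summed over axis 0
// and reshaped to [3].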
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = ops::internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
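  // For example, with Maximum, x_1 = [3, 1] and x_2 = [2, 4] give
  // comparator = [true, false], so gx_1 = [dy_0, 0] and gx_2 = [0, dy_1].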
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + i*Re(x)) = -dy * z_inv
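  // (for x = a + i*b: d(angle)/da = -b / (a^2 + b^2) and
  //  d(angle)/db = a / (a^2 + b^2); multiplying dy by -1 / (b + i*a) below
  //  packs these two partials into the real and imaginary parts of dx)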
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGradHelper emits the two MatMul (or BatchMatMul) operations that
// produce the gradients, with transpositions chosen by the caller.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and validates the op's attrs, then selects the
// proper MatMul products for the gradients based on the input matrix
// transposition combination.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
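//
// For y = MatMul(a, b) (no transposes) with incoming gradient dy, the
// standard results are
//   grad(a) = MatMul(dy, b, transpose_b = true)   // dy * b^T
//   grad(b) = MatMul(a, dy, transpose_a = true)   // a^T * dy
// which is the first branch below; the other three branches apply the same
// identities after accounting for transpose_a / transpose_b (or adj_x /
// adj_y) on the forward op.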
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow