math_grad.cc revision 16001fc526831c7a7f1a3814f517b01008df4c4c
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
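// Most of the unary gradients below are computed as
//   grad(x) = grad(y) * conj(dy/dx),
// so for real dtypes the conjugate is a no-op and the Output is returned
// unchanged.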
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = -1/x^2 = -y^2
  auto dydx = Neg(scope, Square(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dydx = Mul(scope, half, y_inv);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1/x^1/2 = x^-1/2
  // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto dydx = Mul(scope, a, y);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Sub(scope, one, y2);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
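  //       = 1 / sqrt(1 + x^2), since cosh(asinh(x)) = sqrt(1 + x^2)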
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
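  //       = 1 / sqrt(x^2 - 1), since sinh(acosh(x)) = sqrt(x^2 - 1)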
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Mul(scope, y, Sub(scope, one, y));
  // dx = dy * y * (1 - y)
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
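  // y = sign(x) is piecewise constant, so dy/dx = 0 wherever it is defined;
  // the gradient is therefore a zero tensor with the shape of x.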
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = ops::internal::BroadcastGradientArgs(scope, sx_1, sx_2);
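  // rx.r0 and rx.r1 are the reduction indices: the axes along which each
  // gradient must be summed to undo broadcasting before it is reshaped back
  // to the corresponding input shape.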
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
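  // y = real(x); the incoming gradient is re-embedded as the real part of a
  // complex tensor whose imaginary part is zero.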
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
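  // y = imag(x); the incoming gradient becomes the imaginary part of a
  // complex tensor whose real part is zero.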
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
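  // y = conj(x)
  // dx = conj(dy)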
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
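// The first product (x0 * x1, with the adj_x0/adj_x1 flags) is pushed as the
// gradient with respect to the first MatMul input; the second product
// (y0 * y1, with the adj_y0/adj_y1 flags) as the gradient with respect to
// the second input.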
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}
// MatMulGradCommon reads and checks node attr state, and determines the
// proper MatMul products for gradients based on input matrix transposition
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

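  // For y = x_1 * x_2 with no transposes: dx_1 = dy * x_2^T and
  // dx_2 = x_1^T * dy. The four cases below express these products through
  // the transpose/adjoint flags so no explicit Transpose ops are needed.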
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow