math_grad.cc revision 5a1d6d9dac79b46f055462ee52125753524d9f6e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
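// Note: many gradient functions below compute grad(x) = grad(y) * conj(dy/dx)
// (see the per-op comments). For real dtypes the conjugate is the identity,
// so in that case this helper simply forwards `out` without adding a Conj op.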
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
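  // Note: Sign(0) == 0, so the (sub)gradient propagated at x == 0 is 0.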
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
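  // (since x = sinh(y), dx/dy = cosh(y), and so dy/dx = 1 / cosh(y))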
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sign(x) is piecewise constant, so dy/dx = 0 wherever it is defined.
  // Propagate a zero-filled gradient with the shape of x.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
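// For example (hypothetical shapes): if op.input(0) has shape [2, 3] and
// op.input(1) has shape [3], the elementwise gradients gx_1 and gx_2 both
// have the broadcast shape [2, 3]; BroadcastGradientArgs reports that the
// second input was broadcast along axis 0, so gx_2 is summed over axis 0
// and reshaped back to [3], while gx_1 is reshaped (unchanged) to [2, 3].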
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
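  // For example (Maximum, hypothetical values): with x_1 = [2, 5] and
  // x_2 = [3, 5], comparator = GreaterEqual(x_1, x_2) = [false, true], so
  // the incoming gradient [g1, g2] is routed as gx_1 = [0, g2] and
  // gx_2 = [g1, 0]; ties (x_1 == x_2) send the whole gradient to x_1.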
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Real(x). The incoming real-valued gradient becomes the real part of
  // the complex gradient with respect to x; the imaginary part is zero.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Imag(x). The incoming real-valued gradient becomes the imaginary
  // part of the complex gradient with respect to x; the real part is zero.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + i * Re(x)) = -dy * z_inv
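  // One way to see this (sketch): y = atan2(Im(x), Re(x)), so
  //   dy/dRe(x) = -Im(x) / |x|^2 and dy/dIm(x) = Re(x) / |x|^2,
  // and packing these into a complex gradient gives
  //   dy * (-Im(x) + i * Re(x)) / |x|^2 = -dy / (Im(x) + i * Re(x)).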
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Conj(x), so the gradient with respect to x is Conj(dy).
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0; division by 0 returns x.
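// For example, SafeDivHelper(scope, [2, 3, 5, 7], [2, 1, 1, 7]) yields
// [1, 3, 5, 1], and any zero entries in y pass the corresponding x entry
// through unchanged.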
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// Returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of each reduced group.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
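  // For the running example above (a [2, 3] input reduced over axis 1 to a
  // [2]-element output), group_size = 6 / 2 = 3, so each propagated entry
  // of sum_grad is divided by 3.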
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks node attr state, and determines the
// proper MatMul products for the gradients based on the input matrix
// transposition combination.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
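//
// As a reference (with y = MatMul(a, b), incoming gradient dy, and ' denoting
// transpose), the four transpose combinations below reduce to:
//   ta=false, tb=false:  da = dy * b',   db = a' * dy
//   ta=false, tb=true :  da = dy * b,    db = dy' * a
//   ta=true,  tb=false:  da = b * dy',   db = a * dy
//   ta=true,  tb=true :  da = b' * dy',  db = dy' * a'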
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow