math_grad.cc revision 355e25ebcab64e833dfc987638c3e6c79d838266
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
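  // y = sign(x) is piecewise constant, so dy/dx = 0 almost everywhere;
  // propagate a zero gradient with the same shape as x.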
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the shape bookkeeping for binary ops that
// broadcast their inputs: each gradient is summed over the broadcast
// dimensions and reshaped back to the corresponding input's shape.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
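  // y = max(x_1, x_2); the comparator x_1 >= x_2 routes the gradient to x_1
  // (including ties) and to x_2 elsewhere.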
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
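  // y = min(x_1, x_2); the comparator x_1 <= x_2 routes the gradient to x_1
  // (including ties) and to x_2 elsewhere.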
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
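  // y = Real(x)
  // The gradient flows back into the real component only: dx = Complex(dy, 0).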
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
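  // y = Imag(x)
  // The gradient flows back into the imaginary component: dx = Complex(0, dy).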
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
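  // y = Complex(x_1, x_2)
  // dx_1 = Real(dy) and dx_2 = Imag(dy); broadcasting is handled by
  // BinaryGradCommon.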
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
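  // y = Conj(x)
  // dx = Conj(dy)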
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treating x / 0 as x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
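  // The gradient with respect to the data input: SumGradHelper tiles the
  // incoming gradient back to the shape of op.input(0).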
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
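  // y = erf(x)
  // dy/dx = 2 / sqrt(pi) * exp(-x^2)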
  auto grad = grad_inputs[0];
  auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)),
                               grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
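  // y = lgamma(x)
  // dy/dx = digamma(x)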
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please see this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //  = equal([[5],   [[5, 5, 5],
  //           [-3]],  [1, 2, -3]])
  //  = [[1, 1, 1],
  //     [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

// MatMulGradHelper computes the two MatMul (or BatchMatMul) products that
// form the gradients, based on the input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's transposition attrs, then
// determines the proper MatMul products for the gradients based on the
// input matrix transposition combination.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
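  // MatMul stores its transposition state in the transpose_a/transpose_b
  // attrs.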
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
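  // BatchMatMul stores its adjoint state in the adj_x/adj_y attrs.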
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow