math_grad.cc revision 20765b3e1ae3b718699592c98aa9805cb874b6d1
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

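// The gradient functions registered in this file are looked up through
// GradOpRegistry when a client asks for symbolic gradients of a graph built
// with the C++ API. A minimal sketch of the caller's side, assuming the
// AddSymbolicGradients overload from gradients.h that seeds backprop with
// OnesLike:
//
//   Scope root = Scope::NewRootScope();
//   auto x = Placeholder(root, DT_FLOAT);
//   auto y = Square(root, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x}, &grads));
//   // grads[0] is the node built by SquareGrad below; it evaluates to 2 * x.
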
namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

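// The convention used throughout this file for holomorphic functions is
//   grad(x) = grad(y) * conj(dy/dx),
// which reduces to the ordinary chain rule for real dtypes. As a quick
// sanity check with f(x) = x^2 at x = 1 + i: dy/dx = 2 + 2i, so a unit
// upstream gradient propagates as conj(2 + 2i) = 2 - 2i.
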
// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

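// internal::TanhGrad(y, dy) in effect computes dy * (1 - y * y), using the
// identity tanh'(x) = 1 - tanh(x)^2. Passing conj(y) for complex dtypes is
// what restores the grad(y) * conj(dy/dx) convention used elsewhere in this
// file; e.g. at y = 0 the local multiplier is exactly 1.
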
Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

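// A worked example of the broadcasting bookkeeping above (a sketch, not tied
// to any particular op): if x_1 has shape [2, 3] and x_2 has shape [3], the
// forward op broadcasts x_2 across the first dimension, so
// BroadcastGradientArgs returns r0 = [] and r1 = [0]. The per-element
// gradients gx_1 and gx_2 both have shape [2, 3]; dx_1 passes through
// unchanged (the Sum/Reshape is a no-op), while dx_2 becomes
// Reshape(Sum(gx_2, [0]), [3]), i.e. the contributions along the broadcast
// dimension are accumulated.
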
Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
  // Avoid false singularity at x = 0
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

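// Numerical sanity check for PowGrad (plain real case): with x = 2, y = 3 and
// an upstream gradient of 1, gx_1 = y * x^(y - 1) = 3 * 2^2 = 12 and
// gy_1 = z * log(x) = 8 * ln(2) ~= 5.545. For x = 0 the Where3 above forces
// the log(x) factor to 0 instead of letting -inf * 0 produce a NaN.
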
// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

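// Example: for Maximum with x_1 = [2, 5], x_2 = [4, 3] and upstream gradient
// [g1, g2], the comparator GreaterEqual(x_1, x_2) is [false, true], so
// gx_1 = [0, g2] and gx_2 = [g1, 0]. On ties the comparator is true, so the
// whole gradient is routed to x_1 for both Maximum and Minimum.
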
Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x))
  //    = -dy * z_inv, where z_inv = 1 / (Im(x) + iRe(x))
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treating x / 0 as x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

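// Example: SafeDivHelper(scope, [2, 3, 5, 7], [2, 1, 1, 7]) yields
// [1, 3, 5, 1]; a zero in the divisor simply passes the numerator through
// instead of failing.
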
// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// Returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

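// Example with a negative axis: for input_shape = [2, 3, 5, 7] and
// reduction_axes = [-2], the Mod above normalizes the axis to
// (-2 + 4) % 4 = 2, and DynamicStitch produces [2, 3, 1, 7], exactly the
// keep_dims shape of a reduction over axis -2.
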
// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

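// Example: for an input of shape [2, 3] reduced with reduction_indices = [1],
// group_size = (2 * 3) / 2 = 3, so each element of the tiled Sum gradient is
// divided by 3, the size of the group that was averaged.
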
Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)),
                               grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

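// Sanity check: erf'(x) = 2/sqrt(pi) * exp(-x^2), so at x = 0 the local
// gradient is 2/sqrt(pi) ~= 1.1284, and it decays rapidly for large |x|.
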
Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status SelectGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto comparator = op.input(0);
  auto x = op.input(1);
  auto zeros = ZerosLike(scope, x);
  auto grad = grad_inputs[0];

  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);

  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(gx_1);
  grad_outputs->push_back(gx_2);
  return scope.status();
}
REGISTER_GRADIENT_OP("Select", SelectGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // See this comment for details:
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //  = equal([[5],   [[5, 5, 5],
  //           [-3]],  [1, 2, -3]])
  //  = [[1, 1, 1],
  //     [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //  [3, 4],
  //  [5, 6],
  //  [7, 8]
  // ]
  // and we do a Prod operation along axis 0, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //     [
  //       [105/3, 192/4],
  // dz *  [105/5, 192/6],
  //       [105/7, 192/8]
  //     ]
  // If the input contains a zero, this division is impossible, but
  // if we look at the calculation that gave the first gradient entry,
  // (3 * 5 * 7)/3 is equal to 5 * 7, so
  // the trick will be to cumprod the elements on the axis without
  // the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // [3]
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //    [
  //       [ 3.,  4.],
  //       [ 3.,  5.]
  //   ],
  //   [
  //       [ 5.,  6.],
  //       [ 0.,  6.]
  //   ],
  //   [
  //       [ 7.,  8.],
  //       [ 5.,  6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3.,  4.,  3.,  5.],
  //   [ 5.,  6.,  0.,  6.],
  //   [ 7.,  8.,  5.,  6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1.,  1.,  1.,  1.],
  //   [ 3.,  4.,  3.,  5.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [  7.,   8.,   5.,   6.],
  //   [  1.,   1.,   1.,   1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [ 21.,  32.,  15.,  30.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 0.,  36.]
  //   ],
  //   [
  //     [ 21.,  32.],
  //     [ 15.,  30.]
  //   ],
  //   [
  //     [ 15.,  24.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // out =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 21.,  32.],
  //     [ 15.,  24.]
  //   ],
  //   [
  //     [ 0.,   36.],
  //     [ 15.,  30.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto out = Mul(scope, grad_tiled,
                 Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks node attr state, and determines the
// proper MatMul products for gradients based on input matrix transposition
// combinations.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

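// The four branches above are the standard matrix-calculus identities. For
// the plain case C = A * B (no transposes), dA = dC * B^T and dB = A^T * dC,
// which is what the first MatMulGradHelper call builds:
// MatMul(grad, b, TransposeB(true)) and MatMul(a, grad, TransposeA(true)).
// The other branches permute the operands and transpose flags accordingly.
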
Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow