math_grad.cc revision 1fa73c53ab95693f070ce70e6be0c644d83c163a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Returns the conjugate of an Output if it is complex valued; otherwise
// returns the Output unchanged.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

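// The gradient functions below follow the convention used throughout this
// file, grad(x) = grad(y) * conj(dy/dx), which reduces to the ordinary chain
// rule for real dtypes.
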
// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = -1/x^2 = -y^2
  auto dydx = Neg(scope, Square(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = 2 * x
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dydx = Mul(scope, half, y_inv);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1 / sqrt(x) = x^(-1/2)
  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * x^(-1)
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto dydx = Mul(scope, a, y);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Sub(scope, one, y2);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Mul(scope, y, Sub(scope, one, y));
  // dx = dy * y * (1 - y)
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
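  // y = sign(x) is piecewise constant, so dy/dx = 0 wherever it is defined;
  // the gradient is a zero tensor with the shape of x.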
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
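  // y = real(x); the incoming gradient flows back into the real part only,
  // so dx = complex(dy, 0).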
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
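  // y = imag(x); the incoming gradient flows back into the imaginary part
  // only, so dx = complex(0, dy).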
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
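  // y = conj(x); conjugation is its own inverse, so dx = conj(dy).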
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function that emits the two MatMul (or BatchMatMul)
// products dx = x0 * x1 and dy = y0 * y1, applying the requested
// transpose/adjoint flag to each operand.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// Common MatMulGrad code used to read and check node attr state, and to
// determine the proper MatMul products for the gradients based on input
// matrix transposition combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

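  // With Z = MatMul(X, Y) and dZ = grad_inputs[0], the products below follow
  // the standard matmul gradient identities:
  //   ta=false, tb=false: dX = dZ * Y^T,   dY = X^T * dZ
  //   ta=false, tb=true:  dX = dZ * Y,     dY = dZ^T * X
  //   ta=true,  tb=false: dX = Y * dZ^T,   dY = X * dZ
  //   ta=true,  tb=true:  dX = Y^T * dZ^T, dY = dZ^T * X^T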
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow

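// A minimal usage sketch (assuming the C++ gradients API declared in
// tensorflow/cc/framework/gradients.h and a ClientSession): the gradient
// functions registered above are looked up by op name when symbolic
// gradients are requested, e.g.
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] is dy/dx = 2 * x, built by the SquareGrad function above.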