math_grad.cc revision 6731e21360ac2fd2fa4999a12e87ae72acfcb2d2
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = 1 / x
  // dy/dx = -1 / x^2 = -y^2
  // dx = dy * (-y^2)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dx = dy * (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  // dx = dy * (0.5 * (1 / y))
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1 / x^(1/2) = x^(-1/2)
  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * x^(-1)
  // dx = dy * (-1/2 * y * x^(-1))
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto b = Mul(scope, a, y);
  auto dx = Mul(scope, grad_inputs[0], b);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = exp(x)
  // dy/dx = exp(x)
  // dx = dy * y
  grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = expm1(x)
  // df/dx = exp(x)
  // dx = dy * exp(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // f(x) = log(x) = y
  // df/dx = 1 / x
  // dx = dy * (1 / x)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = log1p(x) = y
  // df/dx = 1 / (1 + x)
  // dx = dy * (1 / (1 + x))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  grad_outputs->push_back(
      Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  // dx = dy * (1 - y^2)
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  // dx = dy * y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
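  // The derivative of sign(x) is 0 everywhere except at x == 0, where it is
  // undefined, so the gradient is a zero tensor with the shape of x.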
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  // dx = dy * cos(x)
  auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  // dx = dy * -sin(x)
  auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / (1 - x^2)^(1/2)
  // dx = dy * (1 / (1 - x^2)^(1/2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / (1 - x^2)^(1/2)
  // dx = dy * (-1 / (1 - x^2)^(1/2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  // dx = dy * (1 / cos(x)^2)
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
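  // y = real(x)
  // dx = complex(dy, 0): the incoming gradient feeds the real part only.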
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
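  // y = imag(x)
  // dx = complex(0, dy): the incoming gradient feeds the imaginary part only.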
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
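  // y = conj(x)
  // dx = conj(dy)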
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Helper for MatMulGrad: computes the two MatMul (or BatchMatMul) products
// that form the gradients, for a given combination of input transpositions.
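// The first product pushed onto grad_outputs is the gradient for the op's
// first input, computed as x0 * x1 with transpose (adjoint, in the batch
// case) flags adj_x0 and adj_x1; the second is the gradient for the second
// input, computed as y0 * y1 with flags adj_y0 and adj_y1.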
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// Common code for MatMulGrad and BatchMatMulGrad: reads and checks the op's
// attrs, then selects the proper pair of products for the gradients based on
// the input transposition combination.
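// For C = MatMul(A, B) with incoming gradient dC, the four cases handled
// below reduce to the standard identities (first product is dA, second dB):
//   !ta && !tb:  dA = dC * B^T,    dB = A^T * dC
//   !ta &&  tb:  dA = dC * B,      dB = dC^T * A
//    ta && !tb:  dA = B * dC^T,    dB = A * dC
//    ta &&  tb:  dA = B^T * dC^T,  dB = dC^T * A^T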
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow