math_grad.cc revision b2823a416f876bfbff3022d1b61e358525b40618
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

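// Each gradient function below receives the gradients flowing into the op's
// outputs via 'grad_inputs' (for these single-output ops, grad_inputs[0] is
// dy) and pushes one gradient per op input onto 'grad_outputs' (typically
// dx = dy * dy/dx).
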
Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * (-1 * (y * y))
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dx = dy * (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  // dx = dy * (0.5 * (1 / y))
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1 / x^(1/2) = x^(-1/2)
  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * x^(-1)
  // dx = dy * (-1/2 * y * x^(-1))
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto b = Mul(scope, a, y);
  auto dx = Mul(scope, grad_inputs[0], b);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = exp(x)
  // dy/dx = exp(x)
  // dx = dy * y
  grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // f(x) = log(x) = y
  // df/dx = 1 / x
  // dx = dy * (1 / x)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = log1p(x) = y
  // df/dx = 1 / (1 + x)
  // dx = dy * (1 / (1 + x))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  grad_outputs->push_back(
      Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  // dx = dy * (1 - y^2)
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  // dx = dy * y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
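  // The derivative of sign(x) is zero everywhere it is defined, so the
  // gradient is a zero tensor with the same shape as x.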
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  // dx = dy * cos(x)
  auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  // dx = dy * -sin(x)
  auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / (1 - x * x)^(1/2)
  // dx = dy * (1 / (1 - x * x)^(1/2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / (1 - x * x)^(1/2)
  // dx = dy * (-1 / (1 - x * x)^(1/2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  // dx = dy * (1 / cos(x)^2)
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
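  // y = real(x)
  // The gradient flows back into the real part of the complex input; its
  // imaginary part is zero.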
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
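  // y = imag(x)
  // The gradient flows back into the imaginary part of the complex input; its
  // real part is zero.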
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
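  // y = conj(x)
  // dx = conj(dy)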
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function: emits the two MatMul (or BatchMatMul) products
// dx = x0 * x1 and dy = y0 * y1 that form the gradients, applying the given
// transposition/adjoint flag to each operand.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// Common helper for MatMulGrad and BatchMatMulGrad: reads and checks the
// relevant node attrs, then selects the proper MatMul products for the
// gradients based on the input matrix transposition combinations.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

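  // The branches below apply the standard matrix-calculus identities. For
  // example, for C = MatMul(A, B) with neither input transposed and upstream
  // gradient dC = grad_inputs[0]:
  //   dA = MatMul(dC, B, transpose_b=true)   // dC * B^T
  //   dB = MatMul(A, dC, transpose_a=true)   // A^T * dC
  // The remaining transpose combinations follow analogously.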
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

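// Note: the gradient functions in this file are registered with
// ops::GradOpRegistry via REGISTER_GRADIENT_OP and are looked up when
// symbolic gradients are requested for a graph built with the C++ API.
// A rough usage sketch (assuming the gradient-construction entry point
// declared in tensorflow/cc/framework/gradients.h):
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] holds dy/dx = 2 * x, built via SquareGrad above.
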
}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow