/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

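// The gradient of Pack (a.k.a. Stack) unstacks the incoming gradient along
// the same axis, yielding one gradient slice per stacked input.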
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

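// The gradient of Unpack (Unstack) stacks the per-output gradients back
// together along the original axis.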
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

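// QuantizeAndDequantize and its V2/V3 variants are treated as identities for
// gradient purposes (a straight-through estimator); the extra range and
// num_bits inputs of the newer variants receive no gradient.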
Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

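// Split takes (split_dim, value); the split dimension gets no gradient and
// the value gradient is the concatenation of the output gradients along that
// dimension.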
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

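// Diag and DiagPart are inverses of one another, so each op's gradient is the
// other op applied to the incoming gradient.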
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

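// MatrixDiag writes its input onto the diagonal of an otherwise-zero matrix,
// so its gradient reads the diagonal back out with MatrixDiagPart.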
Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

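// MatrixBandPart zeroes everything outside the band, so the gradient is the
// incoming gradient masked to the same band; the band bounds get no gradient.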
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

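// The gradient of GatherNd scatters the incoming gradient back to the
// gathered positions of a zero tensor shaped like the original input.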
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

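// CheckNumerics is an identity in the forward pass; its gradient re-checks the
// incoming gradient for NaN/Inf values, with an augmented error message.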
Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

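// Reshape, ExpandDims, and Squeeze only rearrange the shape, so in each case
// the gradient is the incoming gradient reshaped back to the input's shape.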
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

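// The gradient of Transpose is a transpose by the inverted permutation.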
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

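// Reversing the incoming gradient again with the same sequence lengths and
// dimensions undoes the forward reversal.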
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

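// ReverseV2 is its own inverse, so the gradient is the incoming gradient
// reversed along the same axes.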
Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

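// For ScatterNd, only the updates input has a gradient: the incoming gradient
// gathered back at the scatter indices.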
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

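// ScatterNdNonAliasingAdd is additive, so the input gradient passes through
// unchanged and the updates gradient is gathered at the scatter indices.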
Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

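// The gradient of Pad/PadV2 slices the incoming gradient back down to the
// original input shape, starting at the "pad before" offsets read from the
// first column of the paddings tensor.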
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // The paddings tensor, shape [Rank(x), 2].
  // Take the first column of the paddings: the amount padded before each
  // dimension, shape [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Flatten it into a 1-D offset vector.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input, which receives no gradient.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

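// SpaceToBatch(ND) and BatchToSpace(ND) are inverses of each other, so each
// gradient applies the opposite op with the same block size/shape and
// paddings/crops; those auxiliary inputs receive no gradient.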
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

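// SpaceToDepth and DepthToSpace are likewise inverses of each other.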
Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

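// The gradient of MirrorPad is the internal MirrorPadGrad op, which folds the
// mirrored padding regions back onto the interior of the tensor.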
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow