array_grad.cc revision d3a41971ccc7c03e8b476a48dc8aa7fbb983431c
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

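// Pack (Stack) joins N tensors along `axis`; its gradient simply unpacks the
// incoming gradient back into N slices along the same axis.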
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

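// Unpack (Unstack) is the inverse of Pack: re-stack all the incoming
// gradients along the original axis.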
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

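// Identity, RefIdentity and QuantizeAndDequantize all pass the incoming
// gradient through unchanged.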
Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeGrad);

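// Split's inputs are (split_dim, value). split_dim gets no gradient; the
// gradient w.r.t. value is the concatenation of the per-slice gradients
// along split_dim.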
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

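// Diag and DiagPart are inverses of each other, so each gradient is the
// other op applied to the incoming gradient; MatrixDiag similarly uses
// MatrixDiagPart.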
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

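// Only the entries kept by MatrixBandPart contribute to the output, so the
// gradient masks the incoming gradient with the same band; the
// num_lower/num_upper inputs get no gradient.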
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

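// The gradient of GatherNd scatters the incoming gradient back into a zero
// tensor of the original shape at the gathered indices.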
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(CheckNumerics(
      scope, grad_inputs[0],
      "Not a number (NaN) or infinity (Inf) values detected in gradient."));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

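// Reshape, ExpandDims and Squeeze only rearrange elements, so their gradients
// reshape the incoming gradient back to the original input shape.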
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

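// Undo the forward permutation by transposing the incoming gradient with the
// inverted permutation.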
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

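// Reversing the incoming gradient again with the same sequence lengths and
// dimensions undoes the forward op; seq_lengths gets no gradient.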
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

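// ReverseV2's gradient reverses along the same dimensions again; the dims
// input gets no gradient.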
Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

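// ScatterNd's inputs are (indices, updates, shape). Only updates gets a
// gradient: gather the incoming gradient back out at the same indices.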
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

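// The gradient of Pad is the slice of the incoming gradient that corresponds
// to the original (unpadded) input; the paddings input gets no gradient.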
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Take the first column of a, i.e. the padding added before each
  // dimension, as a [Rank(x), 1] slice.
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad);

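// SpaceToBatch(ND) and BatchToSpace(ND) are inverses of each other, so each
// gradient applies the opposite op with the same block size and
// paddings/crops; the shape inputs get no gradient.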
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

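// SpaceToDepth and DepthToSpace are inverses of each other, so each gradient
// applies the opposite op with the same block size.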
Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

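// The gradient of MirrorPad uses the internal MirrorPadGrad op, which folds
// the mirrored padding regions of the incoming gradient back onto the entries
// they were copied from; the paddings input gets no gradient.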
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow