array_grad.cc revision 45ef73f0b5677af8d8e3c7ad276070f5a728c079
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

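// Pack stacks its N inputs along `axis`; the gradient unpacks the incoming
// gradient along the same axis, yielding one gradient per packed input.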
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unpack(scope, grad_inputs[0], N, Unpack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

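// Unpack splits its input along `axis`; the gradient packs the per-output
// gradients back together along the same axis.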
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
  grad_outputs->push_back(Pack(scope, grad_inputs, Pack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

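// Identity, RefIdentity and QuantizeAndDequantize all use a straight
// pass-through gradient.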
Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

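// Split's first input is the split dimension and gets no gradient; the
// gradient for the value input is the concatenation of the per-output
// gradients along that dimension.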
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, op.input(0), grad_inputs));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

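// Diag/DiagPart and MatrixDiag/MatrixDiagPart are mutual inverses, so the
// gradient of each op is its counterpart applied to the incoming gradient.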
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

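// MatrixBandPart keeps a band of each inner matrix and zeroes the rest, so
// its gradient applies the same band mask to the incoming gradient; the
// num_lower and num_upper inputs get no gradient.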
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

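// GatherNd's gradient scatters the incoming gradient back into a zero tensor
// shaped like the gathered parameter; the indices input gets no gradient.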
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

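// CheckNumerics passes its input through while failing on NaN/Inf, so its
// gradient passes the incoming gradient through another CheckNumerics to
// catch bad values flowing backwards.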
Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(CheckNumerics(
      scope, grad_inputs[0],
      "Not a number (NaN) or infinity (Inf) values detected in gradient."));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

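// Reshape, ExpandDims and Squeeze only rearrange the shape, so each gradient
// is the incoming gradient reshaped back to the original input's shape.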
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

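// Transpose's gradient transposes the incoming gradient with the inverse of
// the original permutation; the permutation input gets no gradient.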
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

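// Reversal is its own inverse, so the gradients of ReverseSequence and
// Reverse apply the same reversal to the incoming gradient; the seq_lengths
// and dims inputs get no gradient.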
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reverse", ReverseGrad);

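// ScatterNd's gradient with respect to the updates is GatherNd of the
// incoming gradient at the same indices; the indices and shape inputs get
// no gradient.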
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

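// Pad's gradient slices the incoming gradient back down to the original
// input's shape, starting at the per-dimension "pad before" offsets; the
// paddings input gets no gradient.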
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // the paddings input, shape [Rank(x), 2]
  // Slice out the first column of a (the amount padded before each
  // dimension), shape [Rank(x), 1].
  auto size = Pack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Flatten it into a 1-D tensor of slice offsets.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad);

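// SpaceToBatch/BatchToSpace, SpaceToBatchND/BatchToSpaceND and
// SpaceToDepth/DepthToSpace are pairwise inverses, so each gradient applies
// the inverse op to the incoming gradient; any paddings or crops inputs get
// no gradient.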
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

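// The gradient of MirrorPad is computed by the dedicated MirrorPadGrad op
// (written fully qualified below to avoid the name clash with this function);
// the paddings input gets no gradient.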
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(
      tensorflow::ops::MirrorPadGrad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow