array_grad.cc revision 45ef73f0b5677af8d8e3c7ad276070f5a728c079
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");
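
// The REGISTER_GRADIENT_OP / REGISTER_NO_GRADIENT_OP entries in this file
// populate GradOpRegistry (see grad_op_registry.h, included above), which the
// C++ gradient builder in gradients.h consults when differentiating a graph.
// A minimal usage sketch follows; the tensors are illustrative only and not
// part of this file, and it assumes the convenience overload of
// AddSymbolicGradients (declared in gradients.h) that seeds the output
// gradients with ones:
//
//   Scope root = Scope::NewRootScope();
//   auto x = Placeholder(root, DT_FLOAT);
//   auto y = Identity(root, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x}, &grads));
//   // grads[0] is dy/dx, built from the IdentityGrad function registered
//   // below.
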
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unpack(scope, grad_inputs[0], N, Unpack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
  grad_outputs->push_back(Pack(scope, grad_inputs, Pack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, op.input(0), grad_inputs));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);
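
// GatherNd and ScatterNd are adjoint to each other: the gradient of a gather
// scatters the incoming gradient back into the shape of the gathered input
// (and ScatterNdGrad further below gathers it back out). The integer indices
// themselves receive no gradient.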
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(CheckNumerics(
      scope, grad_inputs[0],
      "Not a number (NaN) or infinity (Inf) values detected in gradient."));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reverse", ReverseGrad);

Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
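
// Pad's gradient slices the region corresponding to the original input back
// out of the padded gradient: the first column of the paddings tensor (the
// padding added before each dimension) supplies the begin offsets of the
// slice, and the input's own shape supplies its size.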
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Takes a slice of a. The 1st column. [Rank(x), 1].
  auto size = Pack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad);

Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);
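
// MirrorPad and the dedicated MirrorPadGrad op are likewise paired: the
// gradient of MirrorPad is computed by the MirrorPadGrad kernel (which
// accumulates the reflected border regions back onto the interior), and the
// gradient of MirrorPadGrad in turn re-applies MirrorPad with the same
// paddings and mode.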
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(
      tensorflow::ops::MirrorPadGrad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow