1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <vector> 17 18#include "tensorflow/cc/ops/array_ops_internal.h" 19#include "tensorflow/cc/ops/standard_ops.h" 20#include "tensorflow/core/lib/strings/strcat.h" 21 22#include "tensorflow/cc/framework/grad_op_registry.h" 23#include "tensorflow/cc/framework/gradients.h" 24 25namespace tensorflow { 26namespace ops { 27namespace { 28 29REGISTER_NO_GRADIENT_OP("Const"); 30REGISTER_NO_GRADIENT_OP("StopGradient"); 31REGISTER_NO_GRADIENT_OP("ConcatOffset"); 32REGISTER_NO_GRADIENT_OP("EditDistance"); 33REGISTER_NO_GRADIENT_OP("ZerosLike"); 34REGISTER_NO_GRADIENT_OP("InvertPermutation"); 35REGISTER_NO_GRADIENT_OP("Shape"); 36REGISTER_NO_GRADIENT_OP("ShapeN"); 37REGISTER_NO_GRADIENT_OP("Rank"); 38REGISTER_NO_GRADIENT_OP("Size"); 39REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs"); 40REGISTER_NO_GRADIENT_OP("OneHot"); 41 42Status PackGrad(const Scope& scope, const Operation& op, 43 const std::vector<Output>& grad_inputs, 44 std::vector<Output>* grad_outputs) { 45 int N; 46 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); 47 int axis; 48 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); 49 50 grad_outputs->reserve(N); 51 auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); 52 for (const Output& o : grad_op.output) { 53 grad_outputs->emplace_back(o); 54 } 55 return scope.status(); 56} 57REGISTER_GRADIENT_OP("Pack", PackGrad); 58 59Status UnpackGrad(const Scope& scope, const Operation& op, 60 const std::vector<Output>& grad_inputs, 61 std::vector<Output>* grad_outputs) { 62 int axis; 63 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); 64 grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); 65 return scope.status(); 66} 67REGISTER_GRADIENT_OP("Unpack", UnpackGrad); 68 69Status IdentityGrad(const Scope& scope, const Operation& op, 70 const std::vector<Output>& grad_inputs, 71 std::vector<Output>* grad_outputs) { 72 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 73 return scope.status(); 74} 75REGISTER_GRADIENT_OP("Identity", IdentityGrad); 76 77Status RefIdentityGrad(const Scope& scope, const Operation& op, 78 const std::vector<Output>& grad_inputs, 79 std::vector<Output>* grad_outputs) { 80 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 81 return scope.status(); 82} 83REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad); 84 85Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, 86 const std::vector<Output>& grad_inputs, 87 std::vector<Output>* grad_outputs) { 88 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 89 return scope.status(); 90} 91REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); 92 93Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, 94 const std::vector<Output>& grad_inputs, 95 std::vector<Output>* grad_outputs) { 96 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 97 grad_outputs->push_back(NoGradient()); 98 grad_outputs->push_back(NoGradient()); 99 return scope.status(); 100} 101REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); 102 103Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, 104 const std::vector<Output>& grad_inputs, 105 std::vector<Output>* grad_outputs) { 106 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 107 grad_outputs->push_back(NoGradient()); 108 grad_outputs->push_back(NoGradient()); 109 grad_outputs->push_back(NoGradient()); 110 return scope.status(); 111} 112REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad); 113 114Status SplitGrad(const Scope& scope, const Operation& op, 115 const std::vector<Output>& grad_inputs, 116 std::vector<Output>* grad_outputs) { 117 grad_outputs->push_back(NoGradient()); 118 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0))); 119 return scope.status(); 120} 121REGISTER_GRADIENT_OP("Split", SplitGrad); 122 123Status DiagGrad(const Scope& scope, const Operation& op, 124 const std::vector<Output>& grad_inputs, 125 std::vector<Output>* grad_outputs) { 126 grad_outputs->push_back(DiagPart(scope, grad_inputs[0])); 127 return scope.status(); 128} 129REGISTER_GRADIENT_OP("Diag", DiagGrad); 130 131Status DiagPartGrad(const Scope& scope, const Operation& op, 132 const std::vector<Output>& grad_inputs, 133 std::vector<Output>* grad_outputs) { 134 grad_outputs->push_back(Diag(scope, grad_inputs[0])); 135 return scope.status(); 136} 137REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad); 138 139Status MatrixDiagGrad(const Scope& scope, const Operation& op, 140 const std::vector<Output>& grad_inputs, 141 std::vector<Output>* grad_outputs) { 142 grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0])); 143 return scope.status(); 144} 145REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad); 146 147Status MatrixBandPartGrad(const Scope& scope, const Operation& op, 148 const std::vector<Output>& grad_inputs, 149 std::vector<Output>* grad_outputs) { 150 auto num_lower = op.input(1); 151 auto num_upper = op.input(2); 152 grad_outputs->push_back( 153 MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper)); 154 grad_outputs->push_back(NoGradient()); 155 grad_outputs->push_back(NoGradient()); 156 return scope.status(); 157} 158REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad); 159 160Status GatherNdGrad(const Scope& scope, const Operation& op, 161 const std::vector<Output>& grad_inputs, 162 std::vector<Output>* grad_outputs) { 163 auto ref = op.input(0); 164 auto ref_shape = Shape(scope, ref); 165 auto indices = op.input(1); 166 grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape)); 167 grad_outputs->push_back(NoGradient()); 168 return scope.status(); 169} 170REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad); 171 172Status CheckNumericsGrad(const Scope& scope, const Operation& op, 173 const std::vector<Output>& grad_inputs, 174 std::vector<Output>* grad_outputs) { 175 string message; 176 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); 177 string err_msg = strings::StrCat( 178 "Not a number (NaN) or infinity (Inf) values detected in gradient. ", 179 message); 180 grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); 181 return scope.status(); 182} 183REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad); 184 185Status ReshapeGrad(const Scope& scope, const Operation& op, 186 const std::vector<Output>& grad_inputs, 187 std::vector<Output>* grad_outputs) { 188 auto input_shape = Shape(scope, op.input(0)); 189 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 190 grad_outputs->push_back(NoGradient()); 191 return scope.status(); 192} 193REGISTER_GRADIENT_OP("Reshape", ReshapeGrad); 194 195Status ExpandDimsGrad(const Scope& scope, const Operation& op, 196 const std::vector<Output>& grad_inputs, 197 std::vector<Output>* grad_outputs) { 198 auto input_shape = Shape(scope, op.input(0)); 199 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 200 grad_outputs->push_back(NoGradient()); 201 return scope.status(); 202} 203REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad); 204 205Status SqueezeGrad(const Scope& scope, const Operation& op, 206 const std::vector<Output>& grad_inputs, 207 std::vector<Output>* grad_outputs) { 208 auto input_shape = Shape(scope, op.input(0)); 209 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 210 return scope.status(); 211} 212REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad); 213 214Status TransposeGrad(const Scope& scope, const Operation& op, 215 const std::vector<Output>& grad_inputs, 216 std::vector<Output>* grad_outputs) { 217 auto inverted_perm = InvertPermutation(scope, op.input(1)); 218 grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm)); 219 grad_outputs->push_back(NoGradient()); 220 return scope.status(); 221} 222REGISTER_GRADIENT_OP("Transpose", TransposeGrad); 223 224Status ReverseSequenceGrad(const Scope& scope, const Operation& op, 225 const std::vector<Output>& grad_inputs, 226 std::vector<Output>* grad_outputs) { 227 auto seq_lengths = op.input(1); 228 int batch_dim; 229 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); 230 int seq_dim; 231 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); 232 grad_outputs->push_back( 233 ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, 234 ReverseSequence::BatchDim(batch_dim))); 235 grad_outputs->push_back(NoGradient()); 236 return scope.status(); 237} 238REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad); 239 240Status ReverseGrad(const Scope& scope, const Operation& op, 241 const std::vector<Output>& grad_inputs, 242 std::vector<Output>* grad_outputs) { 243 auto reverse_dims = op.input(1); 244 grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims)); 245 grad_outputs->push_back(NoGradient()); 246 return scope.status(); 247} 248REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad); 249 250Status ScatterNdGrad(const Scope& scope, const Operation& op, 251 const std::vector<Output>& grad_inputs, 252 std::vector<Output>* grad_outputs) { 253 auto indices = op.input(0); 254 grad_outputs->push_back(NoGradient()); 255 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); 256 grad_outputs->push_back(NoGradient()); 257 return scope.status(); 258} 259REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); 260 261Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, 262 const std::vector<Output>& grad_inputs, 263 std::vector<Output>* grad_outputs) { 264 auto indices = op.input(1); 265 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 266 grad_outputs->push_back(NoGradient()); 267 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); 268 return scope.status(); 269} 270REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); 271 272template <bool IsPadV2> 273Status PadGrad(const Scope& scope, const Operation& op, 274 const std::vector<Output>& grad_inputs, 275 std::vector<Output>* grad_outputs) { 276 auto x = op.input(0); 277 auto a = op.input(1); // [Rank(x), 2] 278 // Takes a slice of a. The 1st column. [Rank(x), 1]. 279 auto size = Stack(scope, {Rank(scope, x), 1}); 280 auto pad_before = Slice(scope, a, {0, 0}, size); 281 // Make it a 1-D tensor. 282 auto begin = Reshape(scope, pad_before, {-1}); 283 grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x))); 284 grad_outputs->push_back(NoGradient()); 285 // PadV2 adds a "constant_values" input. 286 if (IsPadV2) { 287 grad_outputs->push_back(NoGradient()); 288 } 289 return scope.status(); 290} 291REGISTER_GRADIENT_OP("Pad", PadGrad<false>); 292REGISTER_GRADIENT_OP("PadV2", PadGrad<true>); 293 294Status SpaceToBatchGrad(const Scope& scope, const Operation& op, 295 const std::vector<Output>& grad_inputs, 296 std::vector<Output>* grad_outputs) { 297 int block_size; 298 TF_RETURN_IF_ERROR( 299 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 300 grad_outputs->push_back( 301 BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); 302 grad_outputs->push_back(NoGradient()); 303 return scope.status(); 304} 305REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad); 306 307Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op, 308 const std::vector<Output>& grad_inputs, 309 std::vector<Output>* grad_outputs) { 310 grad_outputs->push_back( 311 BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2))); 312 grad_outputs->push_back(NoGradient()); 313 grad_outputs->push_back(NoGradient()); 314 return scope.status(); 315} 316REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad); 317 318Status BatchToSpaceGrad(const Scope& scope, const Operation& op, 319 const std::vector<Output>& grad_inputs, 320 std::vector<Output>* grad_outputs) { 321 int block_size; 322 TF_RETURN_IF_ERROR( 323 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 324 grad_outputs->push_back( 325 SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); 326 grad_outputs->push_back(NoGradient()); 327 return scope.status(); 328} 329REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad); 330 331Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op, 332 const std::vector<Output>& grad_inputs, 333 std::vector<Output>* grad_outputs) { 334 grad_outputs->push_back( 335 SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2))); 336 grad_outputs->push_back(NoGradient()); 337 grad_outputs->push_back(NoGradient()); 338 return scope.status(); 339} 340REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad); 341 342Status SpaceToDepthGrad(const Scope& scope, const Operation& op, 343 const std::vector<Output>& grad_inputs, 344 std::vector<Output>* grad_outputs) { 345 int block_size; 346 TF_RETURN_IF_ERROR( 347 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 348 grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); 349 return scope.status(); 350} 351REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad); 352 353Status DepthToSpaceGrad(const Scope& scope, const Operation& op, 354 const std::vector<Output>& grad_inputs, 355 std::vector<Output>* grad_outputs) { 356 int block_size; 357 TF_RETURN_IF_ERROR( 358 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 359 grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); 360 return scope.status(); 361} 362REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad); 363 364Status MirrorPadGrad(const Scope& scope, const Operation& op, 365 const std::vector<Output>& grad_inputs, 366 std::vector<Output>* grad_outputs) { 367 string mode; 368 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); 369 grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( 370 scope, grad_inputs[0], op.input(1), mode)); 371 grad_outputs->push_back(NoGradient()); 372 return scope.status(); 373} 374REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad); 375 376// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4. 377Status MirrorPadGradGrad(const Scope& scope, const Operation& op, 378 const std::vector<Output>& grad_inputs, 379 std::vector<Output>* grad_outputs) { 380 string mode; 381 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); 382 grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); 383 grad_outputs->push_back(NoGradient()); 384 return scope.status(); 385} 386REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad); 387 388} // anonymous namespace 389} // namespace ops 390} // namespace tensorflow 391