array_grad.cc revision d3a41971ccc7c03e8b476a48dc8aa7fbb983431c
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeGrad);

Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
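  // Input 0 of Split is the split dimension (an integer index), so it gets no
  // gradient; the gradient with respect to the split input is the incoming
  // gradients concatenated back together along that same dimension.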
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(CheckNumerics(
      scope, grad_inputs[0],
      "Not a number (NaN) or infinity (Inf) values detected in gradient."));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);
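// Transpose permutes axes by `perm` (input 1); its gradient is a Transpose of
// the incoming gradient by the inverse permutation, which routes each axis
// back to where it came from.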
188REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad); 189 190Status TransposeGrad(const Scope& scope, const Operation& op, 191 const std::vector<Output>& grad_inputs, 192 std::vector<Output>* grad_outputs) { 193 auto inverted_perm = InvertPermutation(scope, op.input(1)); 194 grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm)); 195 grad_outputs->push_back(NoGradient()); 196 return scope.status(); 197} 198REGISTER_GRADIENT_OP("Transpose", TransposeGrad); 199 200Status ReverseSequenceGrad(const Scope& scope, const Operation& op, 201 const std::vector<Output>& grad_inputs, 202 std::vector<Output>* grad_outputs) { 203 auto seq_lengths = op.input(1); 204 int batch_dim; 205 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); 206 int seq_dim; 207 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); 208 grad_outputs->push_back( 209 ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, 210 ReverseSequence::BatchDim(batch_dim))); 211 grad_outputs->push_back(NoGradient()); 212 return scope.status(); 213} 214REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad); 215 216Status ReverseGrad(const Scope& scope, const Operation& op, 217 const std::vector<Output>& grad_inputs, 218 std::vector<Output>* grad_outputs) { 219 auto reverse_dims = op.input(1); 220 grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims)); 221 grad_outputs->push_back(NoGradient()); 222 return scope.status(); 223} 224REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad); 225 226Status ScatterNdGrad(const Scope& scope, const Operation& op, 227 const std::vector<Output>& grad_inputs, 228 std::vector<Output>* grad_outputs) { 229 auto indices = op.input(0); 230 grad_outputs->push_back(NoGradient()); 231 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); 232 grad_outputs->push_back(NoGradient()); 233 return scope.status(); 234} 235REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); 236 237Status PadGrad(const Scope& scope, const Operation& op, 238 const std::vector<Output>& grad_inputs, 239 std::vector<Output>* grad_outputs) { 240 auto x = op.input(0); 241 auto a = op.input(1); // [Rank(x), 2] 242 // Takes a slice of a. The 1st column. [Rank(x), 1]. 243 auto size = Stack(scope, {Rank(scope, x), 1}); 244 auto pad_before = Slice(scope, a, {0, 0}, size); 245 // Make it a 1-D tensor. 
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad);

Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
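// MirrorPadGrad is itself an op, so it needs its own gradient for double
// differentiation: re-pad the incoming gradient with MirrorPad in the same
// mode.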
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
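// A minimal usage sketch (hedged; not part of this file): the gradient
// functions registered above are looked up through the registry when building
// symbolic gradients with AddSymbolicGradients() declared in
// tensorflow/cc/framework/gradients.h. Assuming a root scope and a float
// placeholder `x`, something like:
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Transpose(scope, x, {1, 0});
//   auto dy = OnesLike(scope, y);  // seed gradient for y
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, {dy}, &grads));
//   // grads[0] is produced by the TransposeGrad registered in this file.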