math_grad.cc revision 5a1d6d9dac79b46f055462ee52125753524d9f6e
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);
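// The gradient functions in this file are never called directly; they are
// looked up by op name in GradOpRegistry when a client requests symbolic
// gradients. A minimal usage sketch for the Square gradient registered
// above (illustrative only; the function name is an assumption, not part of
// the original file — it relies on AddSymbolicGradients from gradients.h,
// which is already included):
Status ExampleSquareGradUsage() {
  Scope root = Scope::NewRootScope();
  auto x = Placeholder(root, DT_FLOAT);
  auto y = Square(root, x);
  std::vector<Output> dydx;
  // Looks up "Square" in the registry and invokes SquareGrad above,
  // producing a subgraph that computes dL/dx = dL/dy * 2 * x.
  TF_RETURN_IF_ERROR(AddSymbolicGradients(root, {y}, {x}, &dydx));
  return root.status();
}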
Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);
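// A note on the ConjugateHelper pattern used throughout this file: for a
// holomorphic function y = f(x), TensorFlow's convention for backprop
// through complex tensors is grad(x) = grad(y) * conj(dy/dx). For real
// dtypes ConjugateHelper is a no-op, so the same gradient code serves both
// the real and complex cases. Worked example for Exp: at x = i*pi/2,
// y = exp(x) = i and dy/dx = y, so grad(x) = grad(y) * conj(i)
// = -i * grad(y).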
Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);
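// Note on the fused gradient kernels used by TanhGrad above and SigmoidGrad
// below: assuming the standard kernel definitions,
//   internal::TanhGrad(s, y, dy)    computes dy * (1 - y * y)
//   internal::SigmoidGrad(s, y, dy) computes dy * y * (1 - y)
// each in a single op, which is why neither function materializes dy/dx
// explicitly the way the other gradients in this file do.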
Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // sign(x) is piecewise constant, so its gradient is zero everywhere it
  // is defined: the incoming gradient is dropped and a zero tensor with
  // the input's shape is returned.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);
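// A quick way to sanity-check any elementwise gradient above is to compare
// the symbolic result against the analytic derivative at a few points. A
// hedged sketch (the function name and test values are illustrative
// assumptions, not part of the original file):
Status ExampleSinGradCheck() {
  Scope root = Scope::NewRootScope();
  auto x = Const(root, {0.5f, 1.0f});
  auto y = Sin(root, x);
  std::vector<Output> dydx;
  // Invokes SinGrad above via the registry.
  TF_RETURN_IF_ERROR(AddSymbolicGradients(root, {y}, {x}, &dydx));
  // When evaluated in a ClientSession, dydx[0] should match
  // cos(x) = {0.8776f, 0.5403f} up to rounding.
  return root.status();
}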
Status AtanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1 / x_2
  // dy/dx_2 = -x_1 / x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1 / x_2
  // dy/dx_2 = -x_1 / x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
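// Worked broadcasting example for BinaryGradCommon above: if x_1 has shape
// [2, 3] and x_2 has shape [3], BroadcastGradientArgs yields r0 = [] and
// r1 = [0], since x_2 was broadcast along axis 0 in the forward op. The
// [2, 3] gradient gx_2 is therefore summed over axis 0 and reshaped back to
// [3], while gx_1 passes through unchanged. This reduction is why every
// broadcasting binary gradient here funnels through BinaryGradCommon rather
// than returning gx_1 / gx_2 directly.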
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  //   dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  //   hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  //   y = x_1 at points where comparator is true, and x_2 otherwise.
  // Therefore
  //   dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  //   dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);
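// Tie-breaking note for the Maximum/Minimum gradients above: at points
// where x_1 == x_2, both comparators (GreaterEqual for Maximum, LessEqual
// for Minimum) are true, so the entire incoming gradient is routed to x_1
// and none to x_2; the gradient is not split between the two inputs.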
Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy * z_inv, where z_inv = 1 / (Im(x) + i*Re(x))
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treats x/0 = x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// Returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments:
  //   input_shape = [2, 3, 5, 7]
  //   axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True):
  //   [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // Using DynamicStitch:
  //   indices = { input_rank_range, axes }
  //           = { [0, 1, 2, 3], [1, 2] }
  //   data = { input_shape, axes_ones }
  //        = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  //   [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}
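// Worked example of the negative-axis normalization in ReducedShapeHelper
// above (the concrete values are illustrative): with input_rank = 4,
// reduction_axes = [1, -2] normalizes to [1, 2], since
// Mod(Add(-2, 4), 4) = 2. Both index forms therefore select the same
// dimensions, and the result for input_shape = [2, 3, 5, 7] is still
// [2, 1, 1, 7].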
// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  //   input is
  //     [[a, b, c],
  //      [d, e, f]]
  //   reduction_indices = [1]
  //   Sum = [a + b + c, d + e + f]
  //   if the gradient is [g1, g2]
  //   we want the propagated gradient to be
  //     [[g1, g1, g1],
  //      [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  //   n_input_entries / n_output_entries = group_size
  // e.g. reducing a [2, 3] input over all axes gives group_size = 6 / 1 = 6,
  // so each input entry receives grad / 6.
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // Propagate sum_grad / group_size.
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);
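// Reference for the MatMul gradients defined below. Writing C = A * B for
// the forward op (with optional transposes) and G for the incoming gradient
// dL/dC, and ta/tb for the transpose_a/transpose_b (or adj_x/adj_y) attrs,
// MatMulGradCommon produces:
//   ta=false, tb=false:  dA = G * B^T     dB = A^T * G
//   ta=false, tb=true :  dA = G * B       dB = G^T * A
//   ta=true,  tb=false:  dA = B * G^T     dB = A * G
//   ta=true,  tb=true :  dA = B^T * G^T   dB = G^T * A^T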
// MatMulGradHelper computes the two MatMul (or BatchMatMul) products that
// make up the gradients, for a given input transposition combination.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's attr state, then selects the
// proper MatMul products for the gradients based on the input matrix
// transposition combination.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
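// End-to-end usage sketch (illustrative only; assumes the ClientSession API
// from "tensorflow/cc/client/client_session.h", which this file does not
// include). With the registrations above linked in, a client can build and
// evaluate gradients for any of these ops:
//
//   Scope root = Scope::NewRootScope();
//   auto a = ops::Const(root, {{1.0f, 2.0f}, {3.0f, 4.0f}});
//   auto b = ops::Const(root, {{1.0f, 0.0f}, {0.0f, 1.0f}});
//   auto c = ops::MatMul(root, a, b);
//   std::vector<Output> grads;
//   // Uses MatMulGrad above, seeded with a gradient of ones for c.
//   TF_CHECK_OK(AddSymbolicGradients(root, {c}, {a, b}, &grads));
//   ClientSession session(root);
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session.Run(grads, &outputs));
//   // outputs[0] holds dL/da and outputs[1] holds dL/db.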