math_grad.cc revision 355e25ebcab64e833dfc987638c3e6c79d838266
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}
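
// For complex-valued ops, the gradient functions below follow the convention
// grad(x) = grad(y) * conj(dy/dx); ConjugateHelper supplies the conj() factor
// and is a no-op for real dtypes.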
// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);
Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);
Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Sign is piecewise constant, so dy/dx = 0 everywhere it is defined.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);
Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);
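
// The binary elementwise ops below support broadcasting. To compute each
// input's gradient, the incoming gradient is summed over the dimensions
// along which that input was broadcast and then reshaped back to the
// input's shape; the reduction axes for each input are computed by
// internal::BroadcastGradientArgs (rx.r0 and rx.r1 below).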
// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);
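
// Maximum and Minimum select one input elementwise, so the gradient flows
// to the selected input and is zero for the other. Ties are resolved in
// favor of x_1, since the comparators below (GreaterEqual / LessEqual) are
// true when the inputs are equal.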
// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + i Re(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treats x/0 = x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}
// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}
// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of the reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries / n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // Propagate sum_grad / group_size.
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi =
      Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope, Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // dy/dx = digamma(x), so dx = grad * digamma(x).
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please see this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //            = equal([[5],      [[5, 5, 5],
  //                     [-3]],     [1, 2, -3]])
  //            = [[1, 1, 1],
  //               [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}
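
// For a product C = op(A) * op(B) with incoming gradient dC, the input
// gradients for each transposition combination are:
//   C = A * B      =>  dA = dC * B^T,    dB = A^T * dC
//   C = A * B^T    =>  dA = dC * B,      dB = dC^T * A
//   C = A^T * B    =>  dA = B * dC^T,    dB = A * dC
//   C = A^T * B^T  =>  dA = B^T * dC^T,  dB = dC^T * A^T
// MatMulGradCommon below selects the matching case.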
// MatMulGradCommon reads and checks the node attr state, and determines the
// proper MatMul products for the gradients based on the input matrix
// transposition combination.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul.
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow