math_grad.cc revision 20765b3e1ae3b718699592c98aa9805cb874b6d1
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
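  // internal::ReciprocalGrad(y, dy) computes -dy * y * y, i.e. the gradient
  // of 1/x expressed in terms of the op's output y, so x itself is not
  // needed here.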
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
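  // internal::SigmoidGrad(y, dy) computes dy * y * (1 - y), i.e. the gradient
  // of the logistic function expressed in terms of the op's output y.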
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The derivative of sign(x) is 0 wherever it is defined (the function is
  // piecewise constant), so the gradient is a zero tensor shaped like x.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // gx = grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
  // gy = grad * z * log(x), with log(x) masked out where it is undefined.
  // Avoid the false singularity at x = 0.
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
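  // Where3 (the three-input Select op) picks its second argument where the
  // condition is true and its third argument elsewhere, so each incoming
  // gradient value is routed to exactly one of the two inputs.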
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Real(x)
  // dx = Complex(dy, 0): the gradient flows back into the real part only.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Imag(x)
  // dx = Complex(0, dy): the gradient flows back into the imaginary part only.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = Complex(x_1, x_2)
  // gx_1 = Real(dy), gx_2 = Imag(dy)
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = Conj(x)
  // dx = Conj(dy)
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y are both non-negative, that
// treats x / 0 as x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// Returns a 1-D Tensor: the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
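  // (SafeDivHelper guards the division below: a dimension of size 0 in
  // input_shape would otherwise divide by zero; here 0 / 0 yields 0.)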
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of the reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries / n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad / group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi =
      Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = digamma(x), the derivative of lgamma(x).
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status SelectGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // y = Select(c, x_1, x_2)
  // The incoming gradient is routed to whichever input was selected at each
  // position; the boolean condition receives no gradient.
  auto comparator = op.input(0);
  auto x = op.input(1);
  auto zeros = ZerosLike(scope, x);
  auto grad = grad_inputs[0];

  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);

  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(gx_1);
  grad_outputs->push_back(gx_2);
  return scope.status();
}
REGISTER_GRADIENT_OP("Select", SelectGrad);
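
// Usage sketch (illustrative only; hypothetical client code, not part of
// this file): the functions registered here are looked up by op name when a
// client builds symbolic gradients through
// tensorflow/cc/framework/gradients.h, roughly:
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] is the graph for 2 * x, built by SquareGrad above and
//   // seeded with a ones tensor for dy.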

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please note this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //            = equal([[5],     [[5, 5, 5],
  //                     [-3]],    [1, 2, -3]])
  //            = [[1, 1, 1],
  //               [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //   [3, 4],
  //   [5, 6],
  //   [7, 8]
  // ]
  // and we do a Prod operation on axis 0, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //      [
  //        [105/3, 192/4],
  // dz *   [105/5, 192/6],
  //        [105/7, 192/8]
  //      ]
  // If the input contains a zero, the division is impossible. But if we look
  // at the calculation that gave the first gradient entry,
  // (3 * 5 * 7)/3 is equal to 5 * 7,
  // so the trick is to cumprod the elements on the axis without the element
  // at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
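  // (Reshaping to {-1} turns scalar or multi-dimensional reduction indices
  // into a flat 1-D vector, so the steps below can treat them uniformly.)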
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // 3
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // [3] => 3
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // [2, 2] => 2 * 2 = 4
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //   [
  //     [ 3., 4.],
  //     [ 3., 5.]
  //   ],
  //   [
  //     [ 5., 6.],
  //     [ 0., 6.]
  //   ],
  //   [
  //     [ 7., 8.],
  //     [ 5., 6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3., 4., 3., 5.],
  //   [ 5., 6., 0., 6.],
  //   [ 7., 8., 5., 6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1., 1., 1., 1.],
  //   [ 3., 4., 3., 5.],
  //   [ 15., 24., 0., 30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35., 48., 0., 36.],
  //   [ 7., 8., 5., 6.],
  //   [ 1., 1., 1., 1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35., 48., 0., 36.],
  //   [ 21., 32., 15., 30.],
  //   [ 15., 24., 0., 30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 0., 36.]
  //   ],
  //   [
  //     [ 21., 32.],
  //     [ 15., 30.]
  //   ],
  //   [
  //     [ 15., 24.],
  //     [ 0., 30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // Transpose(y, InvertPermutation(perm)) =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 21., 32.],
  //     [ 15., 24.]
  //   ],
  //   [
  //     [ 0., 36.],
  //     [ 15., 30.],
  //     [ 0., 30.]
  //   ]
  // ]
  auto out = Mul(scope, grad_tiled,
                 Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
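// For Z = MatMul(X, Y) with incoming gradient dZ, the plain (no-transpose)
// case computes
//   dX = MatMul(dZ, Y, transpose_b=true)   // dZ . Y^T
//   dY = MatMul(X, dZ, transpose_a=true)   // X^T . dZ
// The other transpose/adjoint combinations permute the operands and flags
// accordingly (see MatMulGradCommon below).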
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node attr state, and determines the
// proper MatMul products for the gradients based on the input matrix
// transposition combinations.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

  // In the comments below, ^T denotes transpose (adjoint for BatchMatMul).
  if (!ta && !tb) {
    // Z = A . B:     dA = dZ . B^T,  dB = A^T . dZ
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    // Z = A . B^T:   dA = dZ . B,    dB = dZ^T . A
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    // Z = A^T . B:   dA = B . dZ^T,  dB = A . dZ
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  // Z = A^T . B^T:   dA = B^T . dZ^T,  dB = dZ^T . A^T
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow