math_grad.cc revision 8624ecc9e827a40f9b514ff6b8ed925390ca79cc
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);
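// The recurring grad(x) = grad(y) * conj(dy/dx) pattern above and below
// follows TensorFlow's convention for complex gradients: the derivative of a
// holomorphic function is conjugated before being multiplied by the incoming
// gradient. As an illustration (not an additional definition), for
// y = Square(x) with x = a + ib, dy/dx = 2x, so SquareGrad produces
//   grad(y) * conj(2x) = grad(y) * (2a - 2ib).
// For real dtypes ConjugateHelper is a no-op and this reduces to the ordinary
// chain rule.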
Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);
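// The internal::*Grad ops used above (ReciprocalGrad, SqrtGrad, RsqrtGrad)
// and below (TanhGrad, SigmoidGrad) are fused kernels that compute the
// gradient from the op's *output* y and the incoming gradient dy. For
// reference, these are the standard identities they implement:
//   ReciprocalGrad(y, dy) = -dy * y^2
//   SqrtGrad(y, dy)       = dy * 0.5 / y
//   RsqrtGrad(y, dy)      = -0.5 * dy * y^3
//   TanhGrad(y, dy)       = dy * (1 - y^2)
//   SigmoidGrad(y, dy)    = dy * y * (1 - y)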
Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);
Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);
Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
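// Worked illustration of the reduction performed by BinaryGradCommon above
// (shapes here are hypothetical): if op.input(0) has shape [2, 3] and
// op.input(1) has shape [3], the forward op broadcasts x_2 across axis 0, so
// BroadcastGradientArgs returns r0 = {} and r1 = {0}. The element-wise
// gradients gx_1 and gx_2 both have shape [2, 3]; dx_1 passes through a no-op
// reduction, while dx_2 = Reshape(Sum(gx_2, {0}), [3]) sums away the broadcast
// axis so that each returned gradient matches the shape of its input.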
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);
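// Note on ties (an observation about the code above, not a behavioral
// change): when x_1 == x_2, GreaterEqual and LessEqual are both true, so
// MaximumGrad and MinimumGrad route the entire incoming gradient to x_1 and
// send zeros to x_2 at those elements.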
Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv,
  // where z_inv = 1 / (Im(x) + iRe(x)).
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}
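// For reference, the four transposition cases handled by MatMulGradCommon
// below follow the standard matrix-calculus identities for real dtypes
// (writing A = op.input(0), B = op.input(1), dC = grad_inputs[0]; the complex
// case is rejected below):
//   C = A * B      :  dA = dC * B^T,    dB = A^T * dC
//   C = A * B^T    :  dA = dC * B,      dB = dC^T * A
//   C = A^T * B    :  dA = B * dC^T,    dB = A * dC
//   C = A^T * B^T  :  dA = B^T * dC^T,  dB = dC^T * A^T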
// MatMulGradCommon reads and checks node attr state, and determines the
// proper MatMul products for gradients based on input matrix transposition
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
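// Illustrative usage (a sketch, not part of this file): the registrations
// above are consulted by AddSymbolicGradients() from
// tensorflow/cc/framework/gradients.h when building gradient subgraphs in the
// C++ API, roughly as follows:
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] now holds dy/dx, constructed by the SquareGrad function
//   // registered above.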