math_grad.cc revision 16001fc526831c7a7f1a3814f517b01008df4c4c
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");

// Returns the conjugate of an Output if it is complex valued; otherwise
// returns the Output unchanged. The gradient functions below use this so
// that, for complex types, grad(x) = grad(y) * conj(dy/dx).
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = -1/x^2 = -y^2
  auto dydx = Neg(scope, Square(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);
Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dydx = Mul(scope, half, y_inv);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1 / x^(1/2) = x^(-1/2)
  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * x^(-1)
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto dydx = Mul(scope, a, y);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);
Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Sub(scope, one, y2);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Mul(scope, y, Sub(scope, one, y));
  // dx = dy * y * (1 - y)
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Sign is piecewise constant, so its gradient is zero everywhere it is
  // defined; return zeros of the input's shape.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);
Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);
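
// Illustrative note: the binary-op gradients below (Add, Sub, Mul, Div, ...)
// share BinaryGradCommon, which reduces each incoming gradient across the
// broadcast dimensions before reshaping it back to the matching input shape.
// For example, if y = Add(x_1, x_2) with x_1 of shape [2, 3] and x_2 of shape
// [3], BroadcastGradientArgs yields reduction indices r0 = {} and r1 = {0}:
// dx_1 is the incoming [2, 3] gradient unchanged, while dx_2 sums it over
// axis 0 and reshapes the result to [3].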
// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = ops::internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0],
                  Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);
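
// For reference, with y = MatMul(x_1, x_2) and incoming gradient dy, the four
// cases handled in MatMulGradCommon below implement the standard identities:
//   transpose_a=false, transpose_b=false:  dx_1 = dy * x_2^T,   dx_2 = x_1^T * dy
//   transpose_a=false, transpose_b=true:   dx_1 = dy * x_2,     dx_2 = dy^T * x_1
//   transpose_a=true,  transpose_b=false:  dx_1 = x_2 * dy^T,   dx_2 = x_1 * dy
//   transpose_a=true,  transpose_b=true:   dx_1 = x_2^T * dy^T, dx_2 = dy^T * x_1^T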
// MatMulGradHelper is used to compute the two MatMul operations for the
// gradients, based on the input matrix transposition combination.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (!is_batch) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's attr state, then determines
// the proper MatMul products for the gradients based on the input matrix
// transposition combination.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
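
// Usage sketch (illustrative only, not part of this file): gradient functions
// registered above are looked up by op name when a gradient graph is built,
// e.g. with tensorflow::AddSymbolicGradients from
// tensorflow/cc/framework/gradients.h:
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] is the symbolic gradient of y w.r.t. x, built by SquareGrad.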