math_grad.cc revision 1fa73c53ab95693f070ce70e6be0c644d83c163a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).
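// Note on the element-wise gradients below: for y = f(x), each gradient
// function computes
//   grad(x) = grad(y) * conj(dy/dx)
// For real dtypes ConjugateHelper is the identity and this is the ordinary
// chain rule; for complex dtypes the conjugate of the derivative is used so
// that gradients of holomorphic functions follow TensorFlow's complex
// gradient convention.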
Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = -1/x^2 = -y^2
  auto dydx = Neg(scope, Square(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dydx = Mul(scope, half, y_inv);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1/x^1/2 = x^-1/2
  // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto dydx = Mul(scope, a, y);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Sub(scope, one, y2);
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Mul(scope, y, Sub(scope, one, y));
  // dx = dy * y * (1 - y)
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);
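// The derivative of sign(x) is zero everywhere it is defined, so the gradient
// below is simply a zero tensor with the same shape as the input.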
Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);
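// Gradients for ops that move between real and complex tensors: Real and Imag
// map the real-valued incoming gradient back into the corresponding component
// of a complex gradient, and Conj conjugates the incoming gradient.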
Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}
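// For C = MatMul(A, B) with neither input transposed, the chain rule gives
//   dA = MatMul(dC, B, transpose_b = true)   // dC * B^T
//   dB = MatMul(A, dC, transpose_a = true)   // A^T * dC
// The branches in MatMulGradCommon below apply the same rule to the remaining
// transpose_a/transpose_b (or adj_x/adj_y) combinations.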
// MatMulGrad helper common to MatMul and BatchMatMul: reads and checks node
// attr state, and determines the proper MatMul products for the gradients
// based on input matrix transposition combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
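Below is a minimal, illustrative sketch (not part of math_grad.cc) of how the gradient functions registered above are exercised through the C++ symbolic-gradient API. It assumes the public headers tensorflow/cc/framework/gradients.h and tensorflow/cc/client/client_session.h are available at this revision; treat it as a sketch of the call pattern rather than a verbatim test from the repository.

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;       // For brevity in this sketch only.
  using namespace tensorflow::ops;

  Scope scope = Scope::NewRootScope();
  auto x = Const(scope, {1.0f, 2.0f, 3.0f});
  // y = x^2; its gradient comes from SquareGrad registered in math_grad.cc.
  auto y = Square(scope, x);

  // Seed the backward pass with an incoming gradient of ones, then build
  // dy/dx symbolically. AddSymbolicGradients looks up the "Square" gradient
  // function through the registry populated by REGISTER_GRADIENT_OP.
  auto dy = OnesLike(scope, y);
  std::vector<Output> grad_outputs;
  TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, {dy}, &grad_outputs));

  // Run the graph; this should print dy/dx = 2 * x = [2, 4, 6].
  ClientSession session(scope);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run(grad_outputs, &outputs));
  LOG(INFO) << outputs[0].DebugString();
  return 0;
}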