math_grad.cc revision 6731e21360ac2fd2fa4999a12e87ae72acfcb2d2
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * (-1 * (y * y))
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dx = dy * (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  // dx = dy * (0.5 * (1 / y))
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1/x^1/2 = x^-1/2
  // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
  // dx = dy * (-1/2 * y * x^-1)
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto b = Mul(scope, a, y);
  auto dx = Mul(scope, grad_inputs[0], b);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = exp(x)
  // dy/dx = exp(x)
  // dx = dy * y
  grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = expm1(x)
  // df/dx = exp(x)
  // dx = dy * exp(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // f(x) = log(x) = y
  // df/dx = 1 / x
  // dx = dy * (1 / x)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = log1p(x) = y
  // df/dx = 1 / (1 + x)
  // dx = dy * (1 / (1 + x))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  grad_outputs->push_back(
      Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  // dx = dy * (1 - y^2)
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  // dx = dy * y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The derivative of sign(x) is zero (almost) everywhere, so the incoming
  // gradient is ignored and a zero tensor shaped like x is returned.
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  // dx = dy * cos(x)
  auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);
Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  // dx = dy * -sin(x)
  auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / (1 - x * x)^1/2
  // dx = dy * (1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / (1 - x * x)^1/2
  // dx = dy * (- 1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  // dx = dy * (1 / cos(x)^2)
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = real(z); the gradient flows back into the real component only:
  // dz = complex(dy, 0)
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = imag(z); the gradient flows back into the imaginary component only:
  // dz = complex(0, dy)
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = conj(z)
  // dz = conj(dy)
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGrad common used to read and check node attr state, and determine
// proper MatMul products for gradients based on input matrix transposition
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
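For context, the sketch below shows one way the gradient functions registered in this file can be exercised end to end: AddSymbolicGradients (from tensorflow/cc/framework/gradients.h) walks the graph, looks up each op's registered gradient function in the grad-op registry, and calls it with the incoming gradients. The concrete graph, seed values, and main() harness are illustrative only and are not part of math_grad.cc.

// Illustrative harness, not part of math_grad.cc. Assumes the TensorFlow C++
// API: AddSymbolicGradients, ClientSession, and the generated op wrappers.
#include <iostream>
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;       // NOLINT
  using namespace tensorflow::ops;  // NOLINT

  Scope root = Scope::NewRootScope();
  auto x = Const(root, {3.0f, -2.0f});
  auto y = Square(root, x);  // y = x^2

  // Seed dL/dy with ones; AddSymbolicGradients resolves "Square" to SquareGrad
  // through the registry populated by REGISTER_GRADIENT_OP above.
  auto dy = Const(root, {1.0f, 1.0f});
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x}, {dy}, &grads));

  ClientSession session(root);
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run({grads[0]}, &out));
  // dy/dx = 2 * x, so the fetched gradient should be {6, -4}.
  std::cout << out[0].DebugString() << std::endl;
  return 0;
}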