math_grad.cc revision b2823a416f876bfbff3022d1b61e358525b40618
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

// Each gradient function below receives, in grad_inputs, the gradients of the
// loss with respect to the outputs of 'op', and appends to grad_outputs the
// gradients with respect to the inputs of 'op', in input order.

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * (-1 * (y * y))
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dx = dy * (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sqrt(x)
  // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
  // dx = dy * (0.5 * (1 / y))
  auto y_inv = Reciprocal(scope, op.output(0));
  auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = 1/x^1/2 = x^-1/2
  // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
  // dx = dy * (-1/2 * y * x^-1)
  auto x_inv = Reciprocal(scope, op.input(0));
  auto y = op.output(0);
  auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
  auto a = Mul(scope, neghalf, x_inv);
  auto b = Mul(scope, a, y);
  auto dx = Mul(scope, grad_inputs[0], b);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = exp(x)
  // dy/dx = exp(x)
  // dx = dy * y
  grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // f(x) = log(x) = y
  // df/dx = 1 / x
  // dx = dy * (1 / x)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // f(x) = log1p(x) = y
  // df/dx = 1 / (1 + x)
  // dx = dy * (1 / (1 + x))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  grad_outputs->push_back(
      Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tanh(x)
  // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
  // dx = dy * (1 - y^2)
  auto y2 = Square(scope, op.output(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // y = 1 / (1 + exp(-x))
  // dy/dx = y * (1 - y)
  // dx = dy * y * (1 - y)
  auto y = op.output(0);
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sign(x)
  // dy/dx = 0 wherever it is defined (sign is piecewise constant)
  // dx = a zero tensor with the same shape as x
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  // dx = dy * cos(x)
  auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  // dx = dy * -sin(x)
  auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / (1 - x * x)^1/2
  // dx = dy * (1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / (1 - x * x)^1/2
  // dx = dy * (-1 / (1 - x * x)^1/2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  // dx = dy * (1 / cos(x)^2)
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = real(x), where x is complex
  // dx = complex(dy, 0): the real-valued gradient flows back into the real
  // component of x.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = imag(x), where x is complex
  // dx = complex(0, dy): the real-valued gradient flows back into the
  // imaginary component of x.
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = conj(x)
  // dx = conj(dy)
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);
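
// The MatMul gradient code below implements the standard matrix-calculus
// results for C = op_a(A) * op_b(B), where each op_* is either the identity
// or a transpose, as selected by the transpose_a/transpose_b (or adj_x/adj_y)
// attrs. Writing dC for the incoming gradient, the four cases reduce to:
//
//   C = A   * B   :  dA = dC * B^T,    dB = A^T * dC
//   C = A   * B^T :  dA = dC * B,      dB = dC^T * A
//   C = A^T * B   :  dA = B * dC^T,    dB = A * dC
//   C = A^T * B^T :  dA = B^T * dC^T,  dB = dC^T * A^T
//
// MatMulGradCommon selects one of these cases from the node's attrs, and
// MatMulGradHelper emits the two corresponding (Batch)MatMul products.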

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGrad common used to read and check node attr state, and determine
// proper MatMul products for gradients based on input matrix transposition
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            true, op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
                            false, grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
                            true, op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
                          true, grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
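
// Illustrative usage sketch: the gradient functions registered in this file
// are looked up by op name when a client builds symbolic gradients with the
// C++ API. The snippet below assumes the AddSymbolicGradients entry point
// declared in "tensorflow/cc/framework/gradients.h"; treat it as a sketch of
// the intended flow rather than a verbatim, buildable test for this revision.
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Square(scope, x);          // forward op; SquareGrad above applies
//   auto dy = OnesLike(scope, y);       // seed gradient dL/dy
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, {dy}, &grads));
//   // grads[0] holds dL/dx, built by SquareGrad as dy * 2 * x.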