# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The gradient of the tutorial zero_out op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """Computes the gradient of `zero_out` with respect to its single input.

  The returned gradient has the same shape as the op's input. It is zero
  everywhere except at the first element (the position whose every
  coordinate is 0). That element receives the first entry of `grad` once
  `grad` has been flattened to one dimension.

  Args:
    op: The `ZeroOut` `Operation` being differentiated; its `inputs` give
      access to the tensor that was fed to the forward op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    A one-element list holding the gradient with respect to the op's input.
  """
  input_shape = array_ops.shape(op.inputs[0])
  # One zero per input dimension: the coordinates of the first element.
  origin = array_ops.zeros_like(input_shape)
  # Flatten the upstream gradient and take its first entry.
  head_grad = array_ops.reshape(grad, [-1])[0]
  # Scatter that single value to `origin` in an otherwise-zero tensor of the
  # input's shape (a single sparse index, hence the one-element list).
  input_grad = sparse_ops.sparse_to_dense([origin], input_shape, head_grad, 0)
  # The forward op has exactly one input, hence a list of one gradient.
  return [input_grad]