1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""The gradient of the tutorial zero_out op."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import sparse_ops
25
26
@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """Computes the gradient of `zero_out` with respect to its single input.

  The returned gradient has the same shape as the op's input. It is zero
  everywhere except at coordinate [0, 0, ..., 0]. That entry holds the first
  element of `grad` in row-major flattened order.

  Args:
    op: The `zero_out` `Operation` being differentiated; only `op.inputs[0]`
      is read, to recover the input's dynamic shape.
    grad: The incoming gradient with respect to the op's output.

  Returns:
    A one-element list holding the gradient tensor for the op's only input.
  """
  input_tensor = op.inputs[0]
  input_shape = array_ops.shape(input_tensor)
  # One zero per input dimension: the coordinate of the input's first element.
  origin = array_ops.zeros_like(input_shape)
  flat_upstream = array_ops.reshape(grad, [-1])
  # Only this upstream value is propagated back to the input.
  leading_value = flat_upstream[0]
  # Scatter the single value at `origin`; every other entry defaults to 0.
  input_grad = sparse_ops.sparse_to_dense(
      [origin], input_shape, leading_value, 0)
  # One gradient per op input, and zero_out has exactly one input.
  return [input_grad]
45