188c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
288c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower#
388c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
488c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# you may not use this file except in compliance with the License.
588c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# You may obtain a copy of the License at
688c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower#
788c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
888c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower#
988c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
1088c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
1188c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1288c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# See the License for the specific language governing permissions and
1388c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# limitations under the License.
1488c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower# ==============================================================================
1588c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
1688c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower"""Test for version 2 of the zero_out op."""
1788c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
1888c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerfrom __future__ import absolute_import
1988c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerfrom __future__ import division
2088c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerfrom __future__ import print_function
2188c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
2288c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerimport tensorflow as tf
2388c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
2488c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
2588c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerfrom tensorflow.examples.adding_an_op import zero_out_grad_2  # pylint: disable=unused-import
2688c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlowerfrom tensorflow.examples.adding_an_op import zero_out_op_2
2788c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
2888c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
class ZeroOut2Test(tf.test.TestCase):
  """Tests the forward values and gradient of `zero_out_op_2.zero_out`.

  The expected values below encode the op's contract: the first element of
  the input (in row-major order) is kept and every other element becomes 0.
  The gradient tests rely on a gradient being registered for the op —
  presumably by the `zero_out_grad_2` import at the top of the file.
  """

  def _check_gradient(self, values, shape):
    """Asserts the numeric and symbolic gradients of zero_out agree.

    Must be called inside an active test session.

    Args:
      values: nested Python list used to build a float32 input constant.
      shape: tuple giving the shape of both the input and the output.
    """
    inputs = tf.constant(values, dtype=tf.float32)
    outputs = zero_out_op_2.zero_out(inputs)
    max_error = tf.test.compute_gradient_error(inputs, shape, outputs, shape)
    self.assertLess(max_error, 1e-4)

  def test(self):
    """A 1-D input keeps only its leading element."""
    with self.test_session():
      zeroed = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(zeroed.eval(), [5, 0, 0, 0, 0])

  def test_2d(self):
    """A 2-D input keeps only element [0][0]."""
    with self.test_session():
      zeroed = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
      self.assertAllEqual(zeroed.eval(), [[6, 0, 0], [0, 0, 0]])

  def test_grad(self):
    """Gradient check on a rank-1 float input of shape (5,)."""
    with self.test_session():
      self._check_gradient([5, 4, 3, 2, 1], (5,))

  def test_grad_2d(self):
    """Gradient check on a rank-2 float input of shape (2, 3)."""
    with self.test_session():
      self._check_gradient([[6, 5, 4], [3, 2, 1]], (2, 3))
5688c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
5788c9fb09bd667df03cdf7e9f75ff225853ad01e1A. Unique TensorFlower
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
60