1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for division with division imported from __future__.
16
17This file should be exactly the same as division_past_test.py except
18for the __future__ division line.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import numpy as np
26
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import test
30
31
class DivisionTestCase(test.TestCase):
  """Checks TensorFlow `/` and `//` against NumPy scalar division.

  NOTE: the module docstring requires this file to match
  division_past_test.py apart from the `__future__` division import, so
  changes here should be mirrored there.
  """

  def testDivision(self):
    """Test all the different ways to divide.

    For every dtype and every pair of operand values, each operand is used
    either as a raw NumPy scalar or wrapped with `constant_op.constant`, and
    the results of `/` and `//` must equal the pure-NumPy results in both
    dtype and value.
    """
    values = [1, 2, 7, 11]
    functions = (lambda x: x), constant_op.constant
    # TODO(irving): Test int8, int16 once we support casts for those.
    dtypes = np.int32, np.int64, np.float32, np.float64

    # (expected, actual) tensor pairs, all fetched by a single sess.run.
    tensors = []
    # Failure message for the entry of `tensors` at the same index, so a
    # failing assertion says exactly which case broke.
    descriptions = []

    def check(expected, actual, description):
      """Queues `actual` to be compared against `expected` after the run."""
      tensors.append((ops.convert_to_tensor(expected),
                      ops.convert_to_tensor(actual)))
      descriptions.append(description)

    def describe(x, op, y, fx, fy):
      """Builds a message identifying one operand/operator combination."""

      def kind(f):
        return "tensor" if f is constant_op.constant else "numpy"

      return "%s %s %s (dtype=%s, x=%s, y=%s)" % (
          x, op, y, type(x).__name__, kind(fx), kind(fy))

    with self.test_session() as sess:
      for dtype in dtypes:
        for x in map(dtype, values):
          for y in map(dtype, values):
            for fx in functions:
              for fy in functions:
                tf_x = fx(x)
                tf_y = fy(y)
                # Expected value comes from plain NumPy scalar arithmetic.
                check(x / y, tf_x / tf_y, describe(x, "/", y, fx, fy))
                check(x // y, tf_x // tf_y, describe(x, "//", y, fx, fy))
      # Do only one sess.run for speed
      results = sess.run(tensors)
      for description, (expected, actual) in zip(descriptions, results):
        self.assertEqual(expected.dtype, actual.dtype, msg=description)
        self.assertEqual(expected, actual, msg=description)
70
71
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
74