1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for merge layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.keras._impl import keras
24from tensorflow.python.ops import array_ops
25from tensorflow.python.platform import test
26
27
class MergeLayersTest(test.TestCase):
  """Tests for the Keras merge layers and their functional wrappers."""

  def test_merge_add(self):
    """Checks `add` shapes/values, mask propagation and unknown dimensions."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      i3 = keras.layers.Input(shape=(4, 5))

      o = keras.layers.add([i1, i2, i3])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2, i3], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      x3 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2, x3])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)

      # test masking: masking one input must produce an output mask that
      # drops the last (feature) axis.
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      m1 = keras.layers.Masking()(i1)
      layer = keras.layers.Add()
      o = layer([m1, i2])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      mask = layer.output_mask
      self.assertListEqual(mask.get_shape().as_list(), [None, 4])

      # test missing shape: the layer must accept an input whose second
      # dimension is statically unknown. Only the absence of an error is
      # checked here.
      i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
      i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
      layer = keras.layers.Add()
      layer([i1, i2])

  def test_merge_elementwise_errors(self):
    """Checks that `add` rejects incompatible shapes and non-list inputs."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 6))
    # Incompatible trailing dimensions (5 vs. 6).
    with self.assertRaises(ValueError):
      keras.layers.add([i1, i2])
    # A single-element list cannot be merged.
    with self.assertRaises(ValueError):
      keras.layers.add([i1])
    # A bare tensor instead of a list of tensors.
    with self.assertRaises(ValueError):
      keras.layers.add(i1)

  def test_merge_multiply(self):
    """Checks `multiply` output shape and element-wise product values."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      i3 = keras.layers.Input(shape=(4, 5))
      o = keras.layers.multiply([i1, i2, i3])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2, i3], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      x3 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2, x3])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)

  def test_merge_average(self):
    """Checks `average` output shape and mean values."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      o = keras.layers.average([i1, i2])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)

  def test_merge_maximum(self):
    """Checks `maximum` output shape and element-wise max values."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      o = keras.layers.maximum([i1, i2])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)

  def test_merge_minimum(self):
    """Checks `minimum` output shape and element-wise min values."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      o = keras.layers.minimum([i1, i2])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)

  def test_merge_concatenate(self):
    """Checks `concatenate` along an explicit axis and mask propagation."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      o = keras.layers.concatenate([i1, i2], axis=1)
      self.assertListEqual(o.get_shape().as_list(), [None, 8, 5])
      model = keras.models.Model([i1, i2], o)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 8, 5))
      self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)

      # test masking: with the default (last) axis, the feature dimensions
      # are concatenated and the output mask drops the feature axis.
      m1 = keras.layers.Masking()(i1)
      layer = keras.layers.Concatenate()
      o = layer([m1, i2])
      self.assertListEqual(o.get_shape().as_list(), [None, 4, 10])
      mask = layer.output_mask
      self.assertListEqual(mask.get_shape().as_list(), [None, 4])

  def test_concatenate_errors(self):
    """Checks `concatenate` error messages for bad shapes and inputs."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(3, 5))
    # Non-concatenation axes (4 vs. 3) do not match.
    with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'):
      keras.layers.concatenate([i1, i2], axis=-1)
    # A bare tensor instead of a list of tensors.
    with self.assertRaisesRegexp(ValueError, 'called on a list'):
      keras.layers.concatenate(i1, axis=-1)
    # A single-element list cannot be concatenated.
    with self.assertRaisesRegexp(ValueError, 'called on a list'):
      keras.layers.concatenate([i1], axis=-1)

  def test_merge_dot(self):
    """Checks `dot` shapes/values, negative axes, config and shape inference."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4,))
      i2 = keras.layers.Input(shape=(4,))
      o = keras.layers.dot([i1, i2], axes=1)
      self.assertListEqual(o.get_shape().as_list(), [None, 1])
      model = keras.models.Model([i1, i2], o)
      # The config must carry the `axes` argument for serialization.
      config = keras.layers.Dot(axes=1).get_config()
      self.assertEqual(config['axes'], 1)

      x1 = np.random.random((2, 4))
      x2 = np.random.random((2, 4))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 1))
      expected = np.zeros((2, 1))
      expected[0, 0] = np.dot(x1[0], x2[0])
      expected[1, 0] = np.dot(x1[1], x2[1])
      self.assertAllClose(out, expected, atol=1e-4)

      # Test with negative tuple of axes.
      o = keras.layers.dot([i1, i2], axes=(-1, -1))
      self.assertListEqual(o.get_shape().as_list(), [None, 1])
      model = keras.models.Model([i1, i2], o)
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 1))
      self.assertAllClose(out, expected, atol=1e-4)

      # test compute_output_shape
      layer = keras.layers.Dot(axes=-1)
      self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1))

  def test_dot_errors(self):
    """Checks that `dot` validates shapes, input count and shape arguments."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 6))
    i3 = keras.layers.Input(shape=(4, 6))
    # Contracted axes have different sizes (5 vs. 6).
    with self.assertRaises(ValueError):
      keras.layers.dot([i1, i2], axes=-1)
    # A bare tensor instead of a list of tensors.
    with self.assertRaises(ValueError):
      keras.layers.dot(i1, axes=-1)
    # Wrong number of inputs.
    with self.assertRaises(ValueError):
      keras.layers.dot([i1], axes=-1)
    with self.assertRaises(ValueError):
      keras.layers.dot([i1, i2, i3], axes=-1)
    # Build the layer outside the assertRaises block so that only
    # compute_output_shape, not the constructor, can satisfy the expectation.
    dot = keras.layers.Dot(1)
    with self.assertRaises(ValueError):
      dot.compute_output_shape(1)

  def test_merge_subtract(self):
    """Checks `subtract` shape/values and its two-input requirement."""
    with self.test_session():
      i1 = keras.layers.Input(shape=(4, 5))
      i2 = keras.layers.Input(shape=(4, 5))
      y = keras.layers.subtract([i1, i2])
      self.assertListEqual(y.get_shape().as_list(), [None, 4, 5])
      model = keras.models.Model([i1, i2], y)

      x1 = np.random.random((2, 4, 5))
      x2 = np.random.random((2, 4, 5))
      out = model.predict([x1, x2])
      self.assertEqual(out.shape, (2, 4, 5))
      self.assertAllClose(out, x1 - x2, atol=1e-4)

    # Test invalid use cases
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(3, 5))
    # Incompatible shapes (4 vs. 3).
    with self.assertRaises(ValueError):
      keras.layers.subtract([i1, i2])
    # Subtraction takes exactly two inputs.
    with self.assertRaises(ValueError):
      keras.layers.subtract([i1, i1, i1])
224
225
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
228