1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for merge layers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.keras._impl import keras 24from tensorflow.python.ops import array_ops 25from tensorflow.python.platform import test 26 27 28class MergeLayersTest(test.TestCase): 29 30 def test_merge_add(self): 31 with self.test_session(): 32 i1 = keras.layers.Input(shape=(4, 5)) 33 i2 = keras.layers.Input(shape=(4, 5)) 34 i3 = keras.layers.Input(shape=(4, 5)) 35 36 o = keras.layers.add([i1, i2, i3]) 37 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 38 model = keras.models.Model([i1, i2, i3], o) 39 40 x1 = np.random.random((2, 4, 5)) 41 x2 = np.random.random((2, 4, 5)) 42 x3 = np.random.random((2, 4, 5)) 43 out = model.predict([x1, x2, x3]) 44 self.assertEqual(out.shape, (2, 4, 5)) 45 self.assertAllClose(out, x1 + x2 + x3, atol=1e-4) 46 47 # test masking 48 i1 = keras.layers.Input(shape=(4, 5)) 49 i2 = keras.layers.Input(shape=(4, 5)) 50 m1 = keras.layers.Masking()(i1) 51 layer = keras.layers.Add() 52 o = layer([m1, i2]) 53 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 54 mask = layer.output_mask 55 self.assertListEqual(mask.get_shape().as_list(), [None, 4]) 56 57 # test missing shape 58 i1 = array_ops.placeholder(shape=(4, None), dtype='float32') 59 i2 = array_ops.placeholder(shape=(4, 5), dtype='float32') 60 layer = keras.layers.Add() 61 o = layer([i1, i2]) 62 63 def test_merge_elementwise_errors(self): 64 i1 = keras.layers.Input(shape=(4, 5)) 65 i2 = keras.layers.Input(shape=(4, 6)) 66 with self.assertRaises(ValueError): 67 keras.layers.add([i1, i2]) 68 with self.assertRaises(ValueError): 69 keras.layers.add([i1]) 70 with self.assertRaises(ValueError): 71 keras.layers.add(i1) 72 with self.assertRaises(ValueError): 73 keras.layers.add([i1]) 74 75 def test_merge_multiply(self): 76 with self.test_session(): 77 i1 = keras.layers.Input(shape=(4, 5)) 78 i2 = keras.layers.Input(shape=(4, 5)) 79 i3 = keras.layers.Input(shape=(4, 5)) 80 o = keras.layers.multiply([i1, i2, i3]) 81 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 82 model = keras.models.Model([i1, i2, i3], o) 83 84 x1 = np.random.random((2, 4, 5)) 85 x2 = np.random.random((2, 4, 5)) 86 x3 = np.random.random((2, 4, 5)) 87 out = model.predict([x1, x2, x3]) 88 self.assertEqual(out.shape, (2, 4, 5)) 89 self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) 90 91 def test_merge_average(self): 92 with self.test_session(): 93 i1 = keras.layers.Input(shape=(4, 5)) 94 i2 = keras.layers.Input(shape=(4, 5)) 95 o = keras.layers.average([i1, i2]) 96 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 97 model = keras.models.Model([i1, i2], o) 98 99 x1 = np.random.random((2, 4, 5)) 100 x2 = np.random.random((2, 4, 5)) 101 out = model.predict([x1, x2]) 102 self.assertEqual(out.shape, (2, 4, 5)) 103 self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) 104 105 def test_merge_maximum(self): 106 with self.test_session(): 107 i1 = keras.layers.Input(shape=(4, 5)) 108 i2 = keras.layers.Input(shape=(4, 5)) 109 o = keras.layers.maximum([i1, i2]) 110 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 111 model = keras.models.Model([i1, i2], o) 112 113 x1 = np.random.random((2, 4, 5)) 114 x2 = np.random.random((2, 4, 5)) 115 out = model.predict([x1, x2]) 116 self.assertEqual(out.shape, (2, 4, 5)) 117 self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) 118 119 def test_merge_minimum(self): 120 with self.test_session(): 121 i1 = keras.layers.Input(shape=(4, 5)) 122 i2 = keras.layers.Input(shape=(4, 5)) 123 o = keras.layers.minimum([i1, i2]) 124 self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) 125 model = keras.models.Model([i1, i2], o) 126 127 x1 = np.random.random((2, 4, 5)) 128 x2 = np.random.random((2, 4, 5)) 129 out = model.predict([x1, x2]) 130 self.assertEqual(out.shape, (2, 4, 5)) 131 self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) 132 133 def test_merge_concatenate(self): 134 with self.test_session(): 135 i1 = keras.layers.Input(shape=(4, 5)) 136 i2 = keras.layers.Input(shape=(4, 5)) 137 o = keras.layers.concatenate([i1, i2], axis=1) 138 self.assertListEqual(o.get_shape().as_list(), [None, 8, 5]) 139 model = keras.models.Model([i1, i2], o) 140 141 x1 = np.random.random((2, 4, 5)) 142 x2 = np.random.random((2, 4, 5)) 143 out = model.predict([x1, x2]) 144 self.assertEqual(out.shape, (2, 8, 5)) 145 self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4) 146 147 # test masking 148 m1 = keras.layers.Masking()(i1) 149 layer = keras.layers.Concatenate() 150 o = layer([m1, i2]) 151 self.assertListEqual(o.get_shape().as_list(), [None, 4, 10]) 152 mask = layer.output_mask 153 self.assertListEqual(mask.get_shape().as_list(), [None, 4]) 154 155 def test_concatenate_errors(self): 156 i1 = keras.layers.Input(shape=(4, 5)) 157 i2 = keras.layers.Input(shape=(3, 5)) 158 with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'): 159 keras.layers.concatenate([i1, i2], axis=-1) 160 with self.assertRaisesRegexp(ValueError, 'called on a list'): 161 keras.layers.concatenate(i1, axis=-1) 162 with self.assertRaisesRegexp(ValueError, 'called on a list'): 163 keras.layers.concatenate([i1], axis=-1) 164 165 def test_merge_dot(self): 166 with self.test_session(): 167 i1 = keras.layers.Input(shape=(4,)) 168 i2 = keras.layers.Input(shape=(4,)) 169 o = keras.layers.dot([i1, i2], axes=1) 170 self.assertListEqual(o.get_shape().as_list(), [None, 1]) 171 model = keras.models.Model([i1, i2], o) 172 _ = keras.layers.Dot(axes=1).get_config() 173 174 x1 = np.random.random((2, 4)) 175 x2 = np.random.random((2, 4)) 176 out = model.predict([x1, x2]) 177 self.assertEqual(out.shape, (2, 1)) 178 expected = np.zeros((2, 1)) 179 expected[0, 0] = np.dot(x1[0], x2[0]) 180 expected[1, 0] = np.dot(x1[1], x2[1]) 181 self.assertAllClose(out, expected, atol=1e-4) 182 183 # Test with negative tuple of axes. 184 o = keras.layers.dot([i1, i2], axes=(-1, -1)) 185 self.assertListEqual(o.get_shape().as_list(), [None, 1]) 186 model = keras.models.Model([i1, i2], o) 187 out = model.predict([x1, x2]) 188 self.assertEqual(out.shape, (2, 1)) 189 self.assertAllClose(out, expected, atol=1e-4) 190 191 # test compute_output_shape 192 layer = keras.layers.Dot(axes=-1) 193 self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1)) 194 195 def test_dot_errors(self): 196 i1 = keras.layers.Input(shape=(4, 5)) 197 i2 = keras.layers.Input(shape=(4, 6)) 198 i3 = keras.layers.Input(shape=(4, 6)) 199 with self.assertRaises(ValueError): 200 keras.layers.dot([i1, i2], axes=-1) 201 with self.assertRaises(ValueError): 202 keras.layers.dot(i1, axes=-1) 203 with self.assertRaises(ValueError): 204 keras.layers.dot([i1], axes=-1) 205 with self.assertRaises(ValueError): 206 keras.layers.dot([i1, i2, i3], axes=-1) 207 with self.assertRaises(ValueError): 208 dot = keras.layers.Dot(1) 209 dot.compute_output_shape(1) 210 211 def test_merge_subtract(self): 212 i1 = keras.layers.Input(shape=(4, 5)) 213 i2 = keras.layers.Input(shape=(4, 5)) 214 y = keras.layers.subtract([i1, i2]) 215 self.assertEqual(y.get_shape().as_list(), [None, 4, 5]) 216 217 # Test invalid use cases 218 i1 = keras.layers.Input(shape=(4, 5)) 219 i2 = keras.layers.Input(shape=(3, 5)) 220 with self.assertRaises(ValueError): 221 keras.layers.subtract([i1, i2]) 222 with self.assertRaises(ValueError): 223 keras.layers.subtract([i1, i1, i1]) 224 225 226if __name__ == '__main__': 227 test.main() 228