1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for convolutional recurrent layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.keras._impl import keras
24from tensorflow.python.keras._impl.keras import testing_utils
25from tensorflow.python.platform import test
26
27
class ConvLSTMTest(test.TestCase):
  """Unit tests for the `keras.layers.ConvLSTM2D` layer."""

  def test_conv_lstm(self):
    """Checks return-state plumbing and output shapes for both data formats."""
    kernel_rows = 3
    kernel_cols = 3
    out_filters = 2
    batch = 1
    in_channels = 2
    in_rows = 5
    in_cols = 5
    timesteps = 2

    for data_format in ['channels_first', 'channels_last']:
      # Build the 5D input with the channel axis in the position the
      # layer expects for this data format.
      if data_format == 'channels_first':
        shape = (batch, timesteps, in_channels, in_rows, in_cols)
      else:
        shape = (batch, timesteps, in_rows, in_cols, in_channels)
      inputs = np.random.rand(*shape)

      for return_sequences in [True, False]:
        with self.test_session():
          # test for return state:
          x = keras.Input(batch_shape=inputs.shape)
          kwargs = {'data_format': data_format,
                    'return_sequences': return_sequences,
                    'return_state': True,
                    'stateful': True,
                    'filters': out_filters,
                    'kernel_size': (kernel_rows, kernel_cols),
                    'padding': 'valid'}
          layer = keras.layers.ConvLSTM2D(**kwargs)
          layer.build(inputs.shape)
          outputs = layer(x)
          # With return_state=True the layer returns [output, h, c].
          _, states = outputs[0], outputs[1:]
          self.assertEqual(len(states), 2)
          # Running predict on a model whose output is the first state
          # must leave the layer's stored state equal to that output.
          model = keras.models.Model(x, states[0])
          state = model.predict(inputs)
          self.assertAllClose(
              keras.backend.eval(layer.states[0]), state, atol=1e-4)

          # test for output shape:
          testing_utils.layer_test(
              keras.layers.ConvLSTM2D,
              kwargs={'data_format': data_format,
                      'return_sequences': return_sequences,
                      'filters': out_filters,
                      'kernel_size': (kernel_rows, kernel_cols),
                      'padding': 'valid'},
              input_shape=inputs.shape)

  def test_conv_lstm_statefulness(self):
    """Checks that a stateful ConvLSTM2D carries and resets state."""
    # (samples, time, rows, cols, channels) == (1, 2, 5, 5, 2).
    inputs = np.random.rand(1, 2, 5, 5, 2)
    ones = np.ones_like(inputs)

    with self.test_session():
      model = keras.models.Sequential()
      layer = keras.layers.ConvLSTM2D(
          data_format='channels_last',
          return_sequences=False,
          filters=2,
          kernel_size=(3, 3),
          stateful=True,
          batch_input_shape=inputs.shape,
          padding='same')
      model.add(layer)
      model.compile(optimizer='sgd', loss='mse')
      out1 = model.predict(ones)

      # train once so that the states change
      model.train_on_batch(ones, np.random.random(out1.shape))
      out2 = model.predict(ones)

      # if the state is not reset, output should be different
      self.assertNotEqual(out1.max(), out2.max())

      # check that output changes after states are reset
      # (even though the model itself didn't change)
      layer.reset_states()
      out3 = model.predict(ones)
      self.assertNotEqual(out3.max(), out2.max())

      # check that container-level reset_states() works
      model.reset_states()
      out4 = model.predict(ones)
      self.assertAllClose(out3, out4, atol=1e-5)

      # check that the call to `predict` updated the states
      out5 = model.predict(ones)
      self.assertNotEqual(out4.max(), out5.max())

  def test_conv_lstm_regularizers(self):
    """Checks that regularizer losses are registered on build and call."""
    # (samples, time, rows, cols, channels) == (1, 2, 5, 5, 2).
    inputs = np.random.rand(1, 2, 5, 5, 2)

    with self.test_session():
      layer = keras.layers.ConvLSTM2D(
          data_format='channels_last',
          return_sequences=False,
          kernel_size=(3, 3),
          stateful=True,
          filters=2,
          batch_input_shape=inputs.shape,
          kernel_regularizer=keras.regularizers.L1L2(l1=0.01),
          recurrent_regularizer=keras.regularizers.L1L2(l1=0.01),
          activity_regularizer='l2',
          bias_regularizer='l2',
          kernel_constraint='max_norm',
          recurrent_constraint='max_norm',
          bias_constraint='max_norm',
          padding='same')
      layer.build(inputs.shape)
      # Building registers the three weight regularizers...
      self.assertEqual(len(layer.losses), 3)
      # ...and calling the layer adds the activity regularizer loss.
      layer(keras.backend.variable(np.ones(inputs.shape)))
      self.assertEqual(len(layer.losses), 4)

  def test_conv_lstm_dropout(self):
    """Smoke-tests ConvLSTM2D with input and recurrent dropout enabled."""
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.ConvLSTM2D,
          kwargs={'data_format': 'channels_last',
                  'return_sequences': False,
                  'filters': 2,
                  'kernel_size': (3, 3),
                  'padding': 'same',
                  'dropout': 0.1,
                  'recurrent_dropout': 0.1},
          input_shape=(1, 2, 5, 5, 2))
181
182
# Run all test cases in this module under the TensorFlow test runner.
if __name__ == '__main__':
  test.main()
185