# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet"""Tests for GRU layer."""
16f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
17f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import absolute_import
18f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import division
19f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import print_function
20f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
21f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletimport numpy as np
22f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
23eaaa0b93852054dee086a3ed5373cf8bbe3d2fb3Francois Cholletfrom tensorflow.python.keras._impl import keras
24eaaa0b93852054dee086a3ed5373cf8bbe3d2fb3Francois Cholletfrom tensorflow.python.keras._impl.keras import testing_utils
25f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom tensorflow.python.platform import test
26f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
27f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
class GRULayerTest(test.TestCase):
  """Tests for `keras.layers.GRU` run inside TensorFlow test sessions."""

  def test_return_sequences_GRU(self):
    """Runs `layer_test` on a GRU built with `return_sequences=True`."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.GRU,
          kwargs={'units': units,
                  'return_sequences': True},
          input_shape=(num_samples, timesteps, embedding_dim))

  def test_dynamic_behavior_GRU(self):
    """Trains one batch on a GRU whose time dimension is declared as None."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      # The time axis is left unspecified in `input_shape`; the concrete
      # number of timesteps only appears in the data fed below.
      layer = keras.layers.GRU(units, input_shape=(None, embedding_dim))
      model = keras.models.Sequential()
      model.add(layer)
      model.compile('sgd', 'mse')
      x = np.random.random((num_samples, timesteps, embedding_dim))
      y = np.random.random((num_samples, units))
      model.train_on_batch(x, y)

  def test_dropout_GRU(self):
    """Runs `layer_test` on a GRU with input and recurrent dropout of 0.1."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.GRU,
          kwargs={'units': units,
                  'dropout': 0.1,
                  'recurrent_dropout': 0.1},
          input_shape=(num_samples, timesteps, embedding_dim))

  def test_implementation_mode_GRU(self):
    """Runs `layer_test` once per `implementation` value in 0, 1 and 2."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      for mode in [0, 1, 2]:
        testing_utils.layer_test(
            keras.layers.GRU,
            kwargs={'units': units,
                    'implementation': mode},
            input_shape=(num_samples, timesteps, embedding_dim))

  def test_statefulness_GRU(self):
    """Checks stateful GRU outputs across predict, train, reset and masking.

    The model is an `Embedding` (with `mask_zero=True`) followed by a stateful
    GRU. Outputs of successive `predict` calls are compared before and after
    training and state resets, and left-padded input is compared with
    right-padded input.
    """
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    layer_class = keras.layers.GRU
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.Embedding(
              4,
              embedding_dim,
              mask_zero=True,
              input_length=timesteps,
              batch_input_shape=(num_samples, timesteps)))
      layer = layer_class(
          units, return_sequences=False, stateful=True, weights=None)
      model.add(layer)
      model.compile(optimizer='sgd', loss='mse')
      out1 = model.predict(np.ones((num_samples, timesteps)))
      self.assertEqual(out1.shape, (num_samples, units))

      # train once so that the states change
      model.train_on_batch(
          np.ones((num_samples, timesteps)), np.ones((num_samples, units)))
      out2 = model.predict(np.ones((num_samples, timesteps)))

      # if the state is not reset, output should be different
      self.assertNotEqual(out1.max(), out2.max())

      # check that output changes after states are reset
      # (even though the model itself didn't change)
      layer.reset_states()
      out3 = model.predict(np.ones((num_samples, timesteps)))
      self.assertNotEqual(out2.max(), out3.max())

      # check that container-level reset_states() works
      model.reset_states()
      out4 = model.predict(np.ones((num_samples, timesteps)))
      np.testing.assert_allclose(out3, out4, atol=1e-5)

      # check that the call to `predict` updated the states
      out5 = model.predict(np.ones((num_samples, timesteps)))
      self.assertNotEqual(out4.max(), out5.max())

      # Check masking
      layer.reset_states()

      # Zeros at the start of each row: one in row 0, two in row 1.
      left_padded_input = np.ones((num_samples, timesteps))
      left_padded_input[0, :1] = 0
      left_padded_input[1, :2] = 0
      out6 = model.predict(left_padded_input)

      layer.reset_states()

      # Same number of zeros per row, but placed at the end instead.
      right_padded_input = np.ones((num_samples, timesteps))
      right_padded_input[0, -1:] = 0
      right_padded_input[1, -2:] = 0
      out7 = model.predict(right_padded_input)

      # Padding position must not affect the result.
      np.testing.assert_allclose(out7, out6, atol=1e-5)

  def test_regularizers_GRU(self):
    """Checks the number of losses a regularized GRU reports.

    Three losses are expected after `build` (kernel, recurrent and bias
    regularizers are configured), and one input-conditional loss after the
    layer is called (presumably from the activity regularizer).
    """
    embedding_dim = 4
    layer_class = keras.layers.GRU
    with self.test_session():
      layer = layer_class(
          5,
          return_sequences=False,
          weights=None,
          input_shape=(None, embedding_dim),
          kernel_regularizer=keras.regularizers.l1(0.01),
          recurrent_regularizer=keras.regularizers.l1(0.01),
          bias_regularizer='l2',
          activity_regularizer='l1')
      # NOTE(review): `input_shape` above uses embedding_dim (4) while the
      # layer is built with a feature size of 2 to match `x` below; the
      # explicit build shape appears to be the one that matters -- confirm.
      layer.build((None, None, 2))
      self.assertEqual(len(layer.losses), 3)

      x = keras.backend.variable(np.ones((2, 3, 2)))
      layer(x)
      self.assertEqual(len(layer.get_losses_for(x)), 1)

  def test_constraints_GRU(self):
    """Checks that weight constraints are attached to the cell's variables."""
    embedding_dim = 4
    layer_class = keras.layers.GRU
    with self.test_session():
      k_constraint = keras.constraints.max_norm(0.01)
      r_constraint = keras.constraints.max_norm(0.01)
      b_constraint = keras.constraints.max_norm(0.01)
      layer = layer_class(
          5,
          return_sequences=False,
          weights=None,
          input_shape=(None, embedding_dim),
          kernel_constraint=k_constraint,
          recurrent_constraint=r_constraint,
          bias_constraint=b_constraint)
      layer.build((None, None, embedding_dim))
      self.assertEqual(layer.cell.kernel.constraint, k_constraint)
      self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
      self.assertEqual(layer.cell.bias.constraint, b_constraint)

  def test_with_masking_layer_GRU(self):
    """Fits a `Masking` + GRU model for one epoch on random data."""
    layer_class = keras.layers.GRU
    with self.test_session():
      inputs = np.random.random((2, 3, 4))
      targets = np.abs(np.random.random((2, 3, 5)))
      # Normalize the last axis so each target row sums to 1, as used with
      # the categorical cross-entropy loss below.
      targets /= targets.sum(axis=-1, keepdims=True)
      model = keras.models.Sequential()
      model.add(keras.layers.Masking(input_shape=(3, 4)))
      model.add(layer_class(units=5, return_sequences=True, unroll=False))
      model.compile(loss='categorical_crossentropy', optimizer='adam')
      model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)

  def test_from_config_GRU(self):
    """Checks that `get_config` round-trips through `from_config`."""
    layer_class = keras.layers.GRU
    for stateful in (False, True):
      l1 = layer_class(units=1, stateful=stateful)
      l2 = layer_class.from_config(l1.get_config())
      # Use the TestCase assertion rather than a bare `assert`: it is not
      # stripped under `python -O` and reports a diff on failure.
      self.assertEqual(l1.get_config(), l2.get_config())
202f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
203f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
if __name__ == '__main__':
  # Only run the tests when this file is executed directly, not on import.
  test.main()
206