1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for rmsprop."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import itertools
23import math
24
25import numpy as np
26
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import embedding_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import test
35from tensorflow.python.training import rmsprop
36
# Floating-point types that every parameterized case is run with.
_DATA_TYPES = [dtypes.half, dtypes.float32]

# One row per optimizer configuration. Column order:
#   learning_rate, decay, momentum, epsilon, centered, use_resource
_TEST_PARAM_VALUES = [
    [0.5, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.9, 0.0, 1e-3, False, False],
    [0.5, 0.9, 0.0, 1e-3, True, True],
    [0.5, 0.9, 0.0, 1e-3, False, True],
    [0.1, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.95, 0.0, 1e-3, False, False],
    [0.5, 0.95, 0.0, 1e-5, True, False],
    [0.5, 0.95, 0.9, 1e-5, True, False],
]

# Every dtype paired with every configuration row, flattened to
# [dtype, learning_rate, decay, momentum, epsilon, centered, use_resource].
_TESTPARAMS = [
    [tf_dtype] + row
    for tf_dtype, row in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES)
]
55
56
57class RMSPropOptimizerTest(test.TestCase):
58
59  def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum,
60                            epsilon, centered):
61    rms_t = rms * decay + (1 - decay) * g * g
62    denom_t = rms_t + epsilon
63    if centered:
64      mg_t = mg * decay + (1 - decay) * g
65      denom_t -= mg_t * mg_t
66    else:
67      mg_t = mg
68    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
69    var_t = var - mom_t
70    return var_t, mg_t, rms_t, mom_t
71
72  def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
73                                   lr, decay, momentum, epsilon, centered):
74    mg_t = copy.deepcopy(mg)
75    rms_t = copy.deepcopy(rms)
76    mom_t = copy.deepcopy(mom)
77    var_t = copy.deepcopy(var)
78    for i in range(len(gindexs)):
79      gindex = gindexs[i]
80      gvalue = gvalues[i]
81      rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue
82      denom_t = rms_t[gindex] + epsilon
83      if centered:
84        mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue
85        denom_t -= mg_t[gindex] * mg_t[gindex]
86      mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
87      var_t[gindex] = var[gindex] - mom_t[gindex]
88    return var_t, mg_t, rms_t, mom_t
89
90  def testDense(self):
91    # TODO(yori): Use ParameterizedTest when available
92    for (dtype, learning_rate, decay, momentum,
93         epsilon, centered, use_resource) in _TESTPARAMS:
94      with self.test_session(use_gpu=True):
95        # Initialize variables for numpy implementation.
96        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
97        grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
98        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
99        grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
100
101        if use_resource:
102          var0 = resource_variable_ops.ResourceVariable(var0_np)
103          var1 = resource_variable_ops.ResourceVariable(var1_np)
104        else:
105          var0 = variables.Variable(var0_np)
106          var1 = variables.Variable(var1_np)
107        grads0 = constant_op.constant(grads0_np)
108        grads1 = constant_op.constant(grads1_np)
109        opt = rmsprop.RMSPropOptimizer(
110            learning_rate=learning_rate,
111            decay=decay,
112            momentum=momentum,
113            epsilon=epsilon,
114            centered=centered)
115
116        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
117        variables.global_variables_initializer().run()
118
119        mg0 = opt.get_slot(var0, "mg")
120        self.assertEqual(mg0 is not None, centered)
121        mg1 = opt.get_slot(var1, "mg")
122        self.assertEqual(mg1 is not None, centered)
123        rms0 = opt.get_slot(var0, "rms")
124        self.assertTrue(rms0 is not None)
125        rms1 = opt.get_slot(var1, "rms")
126        self.assertTrue(rms1 is not None)
127        mom0 = opt.get_slot(var0, "momentum")
128        self.assertTrue(mom0 is not None)
129        mom1 = opt.get_slot(var1, "momentum")
130        self.assertTrue(mom1 is not None)
131
132        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
133        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
134        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
135        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
136        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
137        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
138
139        # Fetch params to validate initial values
140        self.assertAllClose([1.0, 2.0], var0.eval())
141        self.assertAllClose([3.0, 4.0], var1.eval())
142
143        # Run 4 steps of RMSProp
144        for t in range(1, 5):
145          update.run()
146
147          var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
148              var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
149              decay, momentum, epsilon, centered)
150          var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
151              var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate,
152              decay, momentum, epsilon, centered)
153
154          # Validate updated params
155          if centered:
156            self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
157            self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
158          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
159          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
160          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
161          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
162          self.assertAllCloseAccordingToType(var0_np, var0.eval())
163          self.assertAllCloseAccordingToType(var1_np, var1.eval())
164
165  def testMinimizeSparseResourceVariable(self):
166    for dtype in [dtypes.float32, dtypes.float64]:
167      with self.test_session():
168        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
169        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
170        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
171        loss = pred * pred
172        sgd_op = rmsprop.RMSPropOptimizer(
173            learning_rate=1.0,
174            decay=0.0,
175            momentum=0.0,
176            epsilon=0.0,
177            centered=False).minimize(loss)
178        variables.global_variables_initializer().run()
179        # Fetch params to validate initial values
180        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
181        # Run 1 step of sgd
182        sgd_op.run()
183        # Validate updated params
184        self.assertAllCloseAccordingToType(
185            [[0., 1.]], var0.eval(), atol=0.01)
186
187  def testMinimizeSparseResourceVariableCentered(self):
188    for dtype in [dtypes.float32, dtypes.float64]:
189      with self.test_session():
190        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
191        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
192        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
193        loss = pred * pred
194        sgd_op = rmsprop.RMSPropOptimizer(
195            learning_rate=1.0,
196            decay=0.0,
197            momentum=0.0,
198            epsilon=1.0,
199            centered=True).minimize(loss)
200        variables.global_variables_initializer().run()
201        # Fetch params to validate initial values
202        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
203        # Run 1 step of sgd
204        sgd_op.run()
205        # Validate updated params
206        self.assertAllCloseAccordingToType(
207            [[-111, -138]], var0.eval(), atol=0.01)
208
209  def testSparse(self):
210    # TODO(yori): Use ParameterizedTest when available
211    for (dtype, learning_rate, decay,
212         momentum, epsilon, centered, _) in _TESTPARAMS:
213      with self.test_session(use_gpu=True):
214        # Initialize variables for numpy implementation.
215        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
216        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
217        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
218        grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
219
220        var0 = variables.Variable(var0_np)
221        var1 = variables.Variable(var1_np)
222        grads0_np_indices = np.array([0], dtype=np.int32)
223        grads0 = ops.IndexedSlices(
224            constant_op.constant(grads0_np),
225            constant_op.constant(grads0_np_indices), constant_op.constant([1]))
226        grads1_np_indices = np.array([1], dtype=np.int32)
227        grads1 = ops.IndexedSlices(
228            constant_op.constant(grads1_np),
229            constant_op.constant(grads1_np_indices), constant_op.constant([1]))
230        opt = rmsprop.RMSPropOptimizer(
231            learning_rate=learning_rate,
232            decay=decay,
233            momentum=momentum,
234            epsilon=epsilon,
235            centered=centered)
236        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
237        variables.global_variables_initializer().run()
238
239        mg0 = opt.get_slot(var0, "mg")
240        self.assertEqual(mg0 is not None, centered)
241        mg1 = opt.get_slot(var1, "mg")
242        self.assertEqual(mg1 is not None, centered)
243        rms0 = opt.get_slot(var0, "rms")
244        self.assertTrue(rms0 is not None)
245        rms1 = opt.get_slot(var1, "rms")
246        self.assertTrue(rms1 is not None)
247        mom0 = opt.get_slot(var0, "momentum")
248        self.assertTrue(mom0 is not None)
249        mom1 = opt.get_slot(var1, "momentum")
250        self.assertTrue(mom1 is not None)
251
252        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
253        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
254        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
255        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
256        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
257        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
258
259        # Fetch params to validate initial values
260        self.assertAllClose([1.0, 2.0], var0.eval())
261        self.assertAllClose([3.0, 4.0], var1.eval())
262
263        # Run 4 steps of RMSProp
264        for t in range(1, 5):
265          update.run()
266
267          var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
268              var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
269              learning_rate, decay, momentum, epsilon, centered)
270          var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
271              var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
272              learning_rate, decay, momentum, epsilon, centered)
273
274          # Validate updated params
275          if centered:
276            self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
277            self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
278          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
279          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
280          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
281          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
282          self.assertAllCloseAccordingToType(var0_np, var0.eval())
283          self.assertAllCloseAccordingToType(var1_np, var1.eval())
284
285  def testWithoutMomentum(self):
286    for dtype in [dtypes.half, dtypes.float32]:
287      with self.test_session(use_gpu=True):
288        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
289        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
290        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
291        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
292        opt = rmsprop.RMSPropOptimizer(
293            learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)
294        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
295        variables.global_variables_initializer().run()
296
297        rms0 = opt.get_slot(var0, "rms")
298        self.assertTrue(rms0 is not None)
299        rms1 = opt.get_slot(var1, "rms")
300        self.assertTrue(rms1 is not None)
301        mom0 = opt.get_slot(var0, "momentum")
302        self.assertTrue(mom0 is not None)
303        mom1 = opt.get_slot(var1, "momentum")
304        self.assertTrue(mom1 is not None)
305
306        # Fetch params to validate initial values
307        self.assertAllClose([1.0, 2.0], var0.eval())
308        self.assertAllClose([3.0, 4.0], var1.eval())
309        # Step 1: the rms accumulators where 1. So we should see a normal
310        # update: v -= grad * learning_rate
311        update.run()
312        # Check the root mean square accumulators.
313        self.assertAllCloseAccordingToType(
314            np.array([0.901, 0.901]), rms0.eval())
315        self.assertAllCloseAccordingToType(
316            np.array([0.90001, 0.90001]), rms1.eval())
317        # Check the parameters.
318        self.assertAllCloseAccordingToType(
319            np.array([
320                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
321                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
322            ]), var0.eval())
323        self.assertAllCloseAccordingToType(
324            np.array([
325                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
326                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
327            ]), var1.eval())
328        # Step 2: the root mean square accumulators contain the previous update.
329        update.run()
330        # Check the rms accumulators.
331        self.assertAllCloseAccordingToType(
332            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
333        self.assertAllCloseAccordingToType(
334            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
335        # Check the parameters.
336        self.assertAllCloseAccordingToType(
337            np.array([
338                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
339                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
340                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
341                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
342            ]), var0.eval())
343        self.assertAllCloseAccordingToType(
344            np.array([
345                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
346                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
347                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
348                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
349            ]), var1.eval())
350
351  def testWithMomentum(self):
352    for dtype in [dtypes.half, dtypes.float32]:
353      with self.test_session(use_gpu=True):
354        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
355        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
356        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
357        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
358
359        opt = rmsprop.RMSPropOptimizer(
360            learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
361        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
362        variables.global_variables_initializer().run()
363
364        rms0 = opt.get_slot(var0, "rms")
365        self.assertTrue(rms0 is not None)
366        rms1 = opt.get_slot(var1, "rms")
367        self.assertTrue(rms1 is not None)
368        mom0 = opt.get_slot(var0, "momentum")
369        self.assertTrue(mom0 is not None)
370        mom1 = opt.get_slot(var1, "momentum")
371        self.assertTrue(mom1 is not None)
372
373        # Fetch params to validate initial values
374        self.assertAllClose([1.0, 2.0], var0.eval())
375        self.assertAllClose([3.0, 4.0], var1.eval())
376        # Step 1: rms = 1, mom = 0. So we should see a normal
377        # update: v -= grad * learning_rate
378        update.run()
379        # Check the root mean square accumulators.
380        self.assertAllCloseAccordingToType(
381            np.array([0.901, 0.901]), rms0.eval())
382        self.assertAllCloseAccordingToType(
383            np.array([0.90001, 0.90001]), rms1.eval())
384        # Check the momentum accumulators
385        self.assertAllCloseAccordingToType(
386            np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
387                      (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval())
388        self.assertAllCloseAccordingToType(
389            np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
390                      (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval())
391
392        # Check that the parameters.
393        self.assertAllCloseAccordingToType(
394            np.array([
395                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
396                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))
397            ]), var0.eval())
398        self.assertAllCloseAccordingToType(
399            np.array([
400                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
401                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))
402            ]), var1.eval())
403
404        # Step 2: the root mean square accumulators contain the previous update.
405        update.run()
406        # Check the rms accumulators.
407        self.assertAllCloseAccordingToType(
408            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
409        self.assertAllCloseAccordingToType(
410            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
411        self.assertAllCloseAccordingToType(
412            np.array([
413                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
414                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)),
415                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
416                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))
417            ]), mom0.eval())
418        self.assertAllCloseAccordingToType(
419            np.array([
420                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
421                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)),
422                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
423                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))
424            ]), mom1.eval())
425
426        # Check the parameters.
427        self.assertAllCloseAccordingToType(
428            np.array([
429                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
430                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
431                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))),
432                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
433                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
434                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)))
435            ]), var0.eval())
436
437        self.assertAllCloseAccordingToType(
438            np.array([
439                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
440                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
441                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))),
442                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
443                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
444                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
445            ]), var1.eval())
446
447
# When executed as a script, delegate to `test.main()` from
# `tensorflow.python.platform.test`.
if __name__ == "__main__":
  test.main()
450