1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for tensorflow.python.training.saver.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import math
23import os
24import random
25import shutil
26import tempfile
27import time
28
29import numpy as np
30import six
31
32from google.protobuf.any_pb2 import Any
33from google.protobuf import text_format
34
35from tensorflow.core.protobuf import config_pb2
36from tensorflow.core.protobuf import meta_graph_pb2
37from tensorflow.core.protobuf import queue_runner_pb2
38from tensorflow.core.protobuf import saver_pb2
39from tensorflow.python import pywrap_tensorflow
40from tensorflow.python.client import session
41from tensorflow.python.data.ops import dataset_ops
42from tensorflow.python.eager import context
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import errors_impl
47from tensorflow.python.framework import function
48from tensorflow.python.framework import graph_io
49from tensorflow.python.framework import meta_graph
50from tensorflow.python.framework import ops as ops_lib
51from tensorflow.python.framework import test_util
52from tensorflow.python.lib.io import file_io
53from tensorflow.python.ops import array_ops
54from tensorflow.python.ops import control_flow_ops
55from tensorflow.python.ops import data_flow_ops
56from tensorflow.python.ops import math_ops
57from tensorflow.python.ops import nn_ops
58from tensorflow.python.ops import partitioned_variables
59from tensorflow.python.ops import random_ops
60from tensorflow.python.ops import resource_variable_ops
61from tensorflow.python.ops import sparse_ops
62from tensorflow.python.ops import variable_scope
63from tensorflow.python.ops import variables
64import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
65from tensorflow.python.platform import gfile
66from tensorflow.python.platform import test
67from tensorflow.python.summary import summary
68from tensorflow.python.training import adam
69from tensorflow.python.training import gradient_descent
70from tensorflow.python.training import queue_runner_impl
71from tensorflow.python.training import saver as saver_module
72from tensorflow.python.training import saver_test_utils
73from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
74from tensorflow.python.util import compat
75
76
77@test_util.with_c_api
78class SaverTest(test.TestCase):
79
80  def basicSaveRestore(self, variable_op):
81    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
82
83    with self.test_session(graph=ops_lib.Graph()) as sess:
84      # Build a graph with 2 parameter nodes, and Save and
85      # Restore nodes for them.
86      v0 = variable_op(10.0, name="v0")
87      v1 = variable_op(20.0, name="v1")
88      v2 = saver_test_utils.CheckpointedOp(name="v2")
89      v2_init = v2.insert("k1", 30.0)
90
91      # Initialize all variables
92      if context.in_graph_mode():
93        self.evaluate([variables.global_variables_initializer(), v2_init])
94
95        # Check that the parameter nodes have been initialized.
96      self.assertEqual(10.0, self.evaluate(v0))
97      self.assertEqual(20.0, self.evaluate(v1))
98      self.assertEqual(b"k1", self.evaluate(v2.keys()))
99      self.assertEqual(30.0, self.evaluate(v2.values()))
100
101      # Save the initialized values in the file at "save_path"
102      save = saver_module.Saver(
103          {
104              "v0": v0,
105              "v1": v1,
106              "v2": v2.saveable
107          }, restore_sequentially=True)
108      val = save.save(sess, save_path)
109      self.assertTrue(isinstance(val, six.string_types))
110      self.assertEqual(save_path, val)
111
112    # Start a second session.  In that session the parameter nodes
113    # have not been initialized either.
114    with self.test_session(graph=ops_lib.Graph()) as sess:
115      v0 = variable_op(-1.0, name="v0")
116      v1 = variable_op(-1.0, name="v1")
117      v2 = saver_test_utils.CheckpointedOp(name="v2")
118
119      # Assert that the variables are not initialized.
120      if context.in_graph_mode():
121        self.assertEqual(
122            len(variables.report_uninitialized_variables().eval()), 2)
123        self.assertEqual(0, len(v2.keys().eval()))
124        self.assertEqual(0, len(v2.values().eval()))
125      # Restore the saved values in the parameter nodes.
126      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
127      save.restore(sess, save_path)
128      # Check that the parameter nodes have been restored.
129      self.assertEqual(10.0, self.evaluate(v0))
130      self.assertEqual(20.0, self.evaluate(v1))
131      self.assertEqual(b"k1", self.evaluate(v2.keys()))
132      self.assertEqual(30.0, self.evaluate(v2.values()))
133
134    # Build another graph with 2 nodes, initialized
135    # differently, and a Restore node for them.
136    with self.test_session(graph=ops_lib.Graph()) as sess:
137      v0_2 = variable_op(1000.0, name="v0")
138      v1_2 = variable_op(2000.0, name="v1")
139      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
140      v2_init = v2_2.insert("k1000", 3000.0)
141
142      # Check that the parameter nodes have been initialized.
143      if context.in_graph_mode():
144        init_all_op = [variables.global_variables_initializer(), v2_init]
145        self.evaluate(init_all_op)
146        # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty
147        # table as it claims in eager mode?
148        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
149        self.assertEqual(3000.0, self.evaluate(v2_2.values()))
150      self.assertEqual(1000.0, self.evaluate(v0_2))
151      self.assertEqual(2000.0, self.evaluate(v1_2))
152
153      # Restore the values saved earlier in the parameter nodes.
154      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
155      save2.restore(sess, save_path)
156      # Check that the parameter nodes have been restored.
157      self.assertEqual(10.0, self.evaluate(v0_2))
158      self.assertEqual(20.0, self.evaluate(v1_2))
159      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
160      self.assertEqual(30.0, self.evaluate(v2_2.values()))
161
162  def testBasic(self):
163    self.basicSaveRestore(variables.Variable)
164
165  @test_util.run_in_graph_and_eager_modes()
166  def testResourceBasic(self):
167    self.basicSaveRestore(resource_variable_ops.ResourceVariable)
168
169  def testResourceVariableReadOpsAddedDeterministically(self):
170    graph_defs = []
171    num_graphs = 10
172    for _ in range(num_graphs):
173      with ops_lib.Graph().as_default() as g:
174        for i in range(20):
175          resource_variable_ops.ResourceVariable(i, name="var%s" % i)
176        saver_module.Saver()
177        graph_defs.append(g.as_graph_def())
178    for i in range(num_graphs - 1):
179      self.assertEqual(graph_defs[i], graph_defs[i + 1])
180
181  def testEagerBasic(self):
182    with context.eager_mode():
183      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
184
185      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
186      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
187      save = saver_module.Saver([v1, v2])
188      save.save(None, ckpt_prefix)
189
190      v1.assign(0.0)
191      v2.assign([0, 0])
192      self.assertNear(0.0, self.evaluate(v1), 1e-5)
193      self.assertAllEqual([0, 0], self.evaluate(v2))
194
195      save.restore(None, ckpt_prefix)
196      self.assertNear(3.14, self.evaluate(v1), 1e-5)
197      self.assertAllEqual([1, 2], self.evaluate(v2))
198
199  def testEagerGraphCompatibility(self):
200    # Save from graph mode and restore from eager mode.
201    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
202    with context.graph_mode():
203      with self.test_session(graph=ops_lib.Graph()) as sess:
204        # Create a graph model and save the checkpoint.
205        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
206        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
207        graph_saver = saver_module.Saver([w1, w2])
208        sess.run(variables.global_variables_initializer())
209        graph_saver.save(sess, graph_ckpt_prefix)
210
211    with context.eager_mode():
212      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
213      ops_lib.reset_default_graph()
214
215      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
216      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")
217
218      graph_saver = saver_module.Saver([w1, w2])
219      graph_saver.restore(None, graph_ckpt_prefix)
220
221      self.assertAllEqual(self.evaluate(w1), 1.0)
222      self.assertAllEqual(self.evaluate(w2), 2.0)
223
224    # Save from eager mode and restore from graph mode.
225    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
226    with context.eager_mode():
227      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
228      ops_lib.reset_default_graph()
229
230      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
231      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")
232
233      graph_saver = saver_module.Saver([w3, w4])
234      graph_saver.save(None, eager_ckpt_prefix)
235
236    with context.graph_mode():
237      with self.test_session(graph=ops_lib.Graph()) as sess:
238        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
239        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
240        graph_saver = saver_module.Saver([w3, w4])
241        sess.run(variables.global_variables_initializer())
242        graph_saver.restore(sess, eager_ckpt_prefix)
243        self.assertAllEqual(w3.eval(), 3.0)
244        self.assertAllEqual(w4.eval(), 4.0)
245
246  @test_util.run_in_graph_and_eager_modes()
247  def testResourceSaveRestoreCachingDevice(self):
248    save_path = os.path.join(self.get_temp_dir(), "resource_cache")
249    with self.test_session(graph=ops_lib.Graph()) as sess:
250      v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
251                                                 name="v")
252      if context.in_graph_mode():
253        self.evaluate(variables.global_variables_initializer())
254      else:
255        sess = None
256      save = saver_module.Saver([v])
257      save.save(sess, save_path)
258
259      save2 = saver_module.Saver([v])
260      save2.restore(sess, save_path)
261      self.assertEquals(self.evaluate(v), [1])
262
263  def testSaveCopyRestoreWithSaveRelativePaths(self):
264    """Save, copy checkpoint dir and restore from copied dir.
265
266    This only works for save_relative_paths=True.
267    """
268    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
269    os.mkdir(save_dir1)
270    save_path1 = os.path.join(save_dir1, "save_copy_restore")
271
272    # Build a graph with 2 parameter nodes, and Save and
273    # Restore nodes for them.
274    v0 = variables.Variable(10.0, name="v0")
275    v1 = variables.Variable(20.0, name="v1")
276    v2 = saver_test_utils.CheckpointedOp(name="v2")
277    v2_init = v2.insert("k1", 30.0)
278    save = saver_module.Saver(
279        var_list={
280            "v0": v0,
281            "v1": v1,
282            "v2": v2.saveable},
283        restore_sequentially=True,
284        save_relative_paths=True)
285    init_all_op = [variables.global_variables_initializer(), v2_init]
286
287    with self.test_session() as sess:
288      # Initialize all variables
289      sess.run(init_all_op)
290
291      # Check that the parameter nodes have been initialized.
292      self.assertEqual(10.0, v0.eval())
293      self.assertEqual(20.0, v1.eval())
294      self.assertEqual(b"k1", v2.keys().eval())
295      self.assertEqual(30.0, v2.values().eval())
296
297      # Save the initialized values in the file at "save_path"
298      val = save.save(sess, save_path1)
299      self.assertTrue(isinstance(val, six.string_types))
300      self.assertEqual(save_path1, val)
301
302    self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)
303    save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
304    os.renames(save_dir1, save_dir2)
305    save_path2 = os.path.join(save_dir2, "save_copy_restore")
306    self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)
307
308    # Start a second session.  In that session the parameter nodes
309    # have not been initialized either.
310    with self.test_session() as sess:
311      v0 = variables.Variable(-1.0, name="v0")
312      v1 = variables.Variable(-1.0, name="v1")
313      v2 = saver_test_utils.CheckpointedOp(name="v2")
314      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
315
316      # Assert that the variables are not initialized.
317      self.assertEqual(
318          len(variables.report_uninitialized_variables().eval()), 2)
319      self.assertEqual(0, len(v2.keys().eval()))
320      self.assertEqual(0, len(v2.values().eval()))
321
322      # Restore the saved values in the parameter nodes.
323      save.restore(sess, save_path2)
324      # Check that the parameter nodes have been restored.
325      self.assertEqual(10.0, v0.eval())
326      self.assertEqual(20.0, v1.eval())
327      self.assertEqual(b"k1", v2.keys().eval())
328      self.assertEqual(30.0, v2.values().eval())
329
330  def testFilenameTensor(self):
331    v0 = variables.Variable(0, name="v0")
332    filename = b"somerandomfilename"
333    save = saver_module.Saver({"v0": v0}, filename=filename)
334    with self.test_session() as sess:
335      tensor = sess.graph.get_tensor_by_name(
336          save.saver_def.filename_tensor_name)
337      self.assertEqual(sess.run(tensor), filename)
338
339  def testInvalidPath(self):
340    v0 = variables.Variable(0, name="v0")
341    for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
342      with self.test_session() as sess:
343        save = saver_module.Saver({"v0": v0}, write_version=ver)
344        with self.assertRaisesRegexp(errors.NotFoundError,
345                                     "Failed to find any matching files for"):
346          save.restore(sess, "invalid path")
347
348  def testInt64(self):
349    save_path = os.path.join(self.get_temp_dir(), "int64")
350
351    with self.test_session() as sess:
352      # Build a graph with 1 node, and save and restore for them.
353      v = variables.Variable(np.int64(15), name="v")
354      save = saver_module.Saver({"v": v}, restore_sequentially=True)
355      variables.global_variables_initializer().run()
356
357      # Save the initialized values in the file at "save_path"
358      val = save.save(sess, save_path)
359      self.assertTrue(isinstance(val, six.string_types))
360      self.assertEqual(save_path, val)
361
362      with self.test_session() as sess:
363        v = variables.Variable(np.int64(-1), name="v")
364        save = saver_module.Saver({"v": v})
365
366      with self.assertRaisesWithPredicateMatch(
367          errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
368        sess.run(v)
369
370      # Restore the saved values in the parameter nodes.
371      save.restore(sess, save_path)
372      # Check that the parameter nodes have been restored.
373      self.assertEqual(np.int64(15), v.eval())
374
375  def testSomeErrors(self):
376    with ops_lib.Graph().as_default():
377      v0 = variables.Variable([10.0], name="v0")
378      v1 = variables.Variable([20.0], name="v1")
379      v2 = variables.Variable([20.0], name="v2")
380      v2._set_save_slice_info(
381          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
382
383      # By default the name used for "v2" will be "v1" and raise an error.
384      with self.assertRaisesRegexp(ValueError, "same name: v1"):
385        saver_module.Saver([v0, v1, v2])
386
387      # The names are different and will work.
388      saver_module.Saver({"vee1": v1, "other": [v2]})
389
390      # Partitioned variables also cause name conflicts.
391      p_v1 = variable_scope.get_variable(
392          "p_v1",
393          shape=[4, 5],
394          partitioner=partitioned_variables.fixed_size_partitioner(
395              num_shards=2))
396      p_v2 = variable_scope.get_variable(
397          "p_v2",
398          shape=[4, 5],
399          partitioner=partitioned_variables.fixed_size_partitioner(
400              num_shards=2))
401      p_v2._name = "p_v1"
402      with self.assertRaisesRegexp(ValueError, "same name: p_v1"):
403        saver_module.Saver([p_v1, p_v2])
404
405  def testSameName(self):
406    with ops_lib.Graph().as_default():
407      v0 = variables.Variable([10.0], name="v0")
408      v2 = saver_test_utils.CheckpointedOp(name="v2")
409
410      # Saving one variable under two names raises an error.
411      with self.assertRaisesRegexp(
412          ValueError, "The same saveable will be restored with two names: v0"):
413        saver_module.Saver({"v0": v0, "v0too": v0})
414
415      # Ditto for custom saveables.
416      with self.assertRaisesRegexp(
417          ValueError, "The same saveable will be restored with two names: v2"):
418        saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})
419
420      # Verify non-duplicate names work.
421      saver_module.Saver({"v0": v0, "v2": v2.saveable})
422
423  def testBasicsWithListOfVariables(self):
424    save_path = os.path.join(self.get_temp_dir(), "basics_with_list")
425
426    with self.test_session(graph=ops_lib.Graph()) as sess:
427      # Build a graph with 2 parameter nodes, and Save and
428      # Restore nodes for them.
429      v0 = variables.Variable(10.0, name="v0")
430      v1 = variables.Variable(20.0, name="v1")
431      v2 = saver_test_utils.CheckpointedOp(name="v2")
432      v2_init = v2.insert("k1", 30.0)
433      save = saver_module.Saver([v0, v1, v2.saveable])
434      variables.global_variables_initializer().run()
435      v2_init.run()
436
437      # Check that the parameter nodes have been initialized.
438      self.assertEqual(10.0, v0.eval())
439      self.assertEqual(20.0, v1.eval())
440      self.assertEqual(b"k1", v2.keys().eval())
441      self.assertEqual(30.0, v2.values().eval())
442
443      # Save the initialized values in the file at "save_path"
444      val = save.save(sess, save_path)
445      self.assertTrue(isinstance(val, six.string_types))
446      self.assertEqual(save_path, val)
447
448    # Start a second session.  In that session the variables
449    # have not been initialized either.
450    with self.test_session(graph=ops_lib.Graph()) as sess:
451      v0 = variables.Variable(-1.0, name="v0")
452      v1 = variables.Variable(-1.0, name="v1")
453      v2 = saver_test_utils.CheckpointedOp(name="v2")
454      save = saver_module.Saver([v0, v1, v2.saveable])
455
456      with self.assertRaisesWithPredicateMatch(
457          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
458        sess.run(v0)
459      with self.assertRaisesWithPredicateMatch(
460          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
461        sess.run(v1)
462      self.assertEqual(0, len(v2.keys().eval()))
463      self.assertEqual(0, len(v2.values().eval()))
464
465      # Restore the saved values in the parameter nodes.
466      save.restore(sess, save_path)
467      # Check that the parameter nodes have been restored.
468      self.assertEqual(10.0, v0.eval())
469      self.assertEqual(20.0, v1.eval())
470      self.assertEqual(b"k1", v2.keys().eval())
471      self.assertEqual(30.0, v2.values().eval())
472
473    # Build another graph with 2 nodes, initialized
474    # differently, and a Restore node for them.
475    with self.test_session(graph=ops_lib.Graph()) as sess:
476      v0_2 = variables.Variable(1000.0, name="v0")
477      v1_2 = variables.Variable(2000.0, name="v1")
478      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
479      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
480      v2_2.insert("k1000", 3000.0).run()
481      variables.global_variables_initializer().run()
482
483      # Check that the parameter nodes have been initialized.
484      self.assertEqual(1000.0, v0_2.eval())
485      self.assertEqual(2000.0, v1_2.eval())
486      self.assertEqual(b"k1000", v2_2.keys().eval())
487      self.assertEqual(3000.0, v2_2.values().eval())
488      # Restore the values saved earlier in the parameter nodes.
489      save2.restore(sess, save_path)
490      # Check that the parameter nodes have been restored.
491      self.assertEqual(10.0, v0_2.eval())
492      self.assertEqual(20.0, v1_2.eval())
493      self.assertEqual(b"k1", v2_2.keys().eval())
494      self.assertEqual(30.0, v2_2.values().eval())
495
496  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
497    with self.test_session(graph=ops_lib.Graph()) as sess:
498      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
499      save = saver_module.Saver({var_name: var})
500      if context.in_graph_mode():
501        self.evaluate(var.initializer)
502      val = save.save(sess, save_path)
503      self.assertEqual(save_path, val)
504    with self.test_session(graph=ops_lib.Graph()) as sess:
505      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
506      save = saver_module.Saver({var_name: var})
507      save.restore(sess, save_path)
508      self.assertAllClose(var_value, self.evaluate(var))
509
510  def testCacheRereadsFile(self):
511    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
512    # Save and reload one Variable named "var0".
513    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
514    # Save and reload one Variable named "var1" in the same file.
515    # The cached readers should know to re-read the file.
516    self._SaveAndLoad("var1", 1.1, 2.2, save_path)
517
518  def testAllowEmpty(self):
519    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
520    with self.test_session() as sess:
521      _ = constant_op.constant(1)
522      save = saver_module.Saver(allow_empty=True)
523      val = save.save(sess, save_path)
524      self.assertIsNone(val)
525    with self.test_session() as sess:
526      save = saver_module.Saver(allow_empty=True)
527      save.restore(sess, save_path)
528
529  def testGPU(self):
530    if not test.is_gpu_available():
531      return
532    save_path = os.path.join(self.get_temp_dir(), "gpu")
533    with session.Session("", graph=ops_lib.Graph()) as sess:
534      with sess.graph.device(test.gpu_device_name()):
535        v0_1 = variables.Variable(123.45)
536      save = saver_module.Saver({"v0": v0_1})
537      variables.global_variables_initializer().run()
538      save.save(sess, save_path)
539
540    with session.Session("", graph=ops_lib.Graph()) as sess:
541      with sess.graph.device(test.gpu_device_name()):
542        v0_2 = variables.Variable(543.21)
543      save = saver_module.Saver({"v0": v0_2})
544      variables.global_variables_initializer().run()
545
546  def testSharedServerOnGPU(self):
547    if not test.is_gpu_available():
548      return
549    save_path = os.path.join(self.get_temp_dir(), "gpu")
550    with session.Session("", graph=ops_lib.Graph()) as sess:
551      with sess.graph.device(test.gpu_device_name()):
552        v0_1 = variables.Variable(123.45)
553      save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
554      variables.global_variables_initializer().run()
555      save.save(sess, save_path)
556
557    with session.Session("", graph=ops_lib.Graph()) as sess:
558      with sess.graph.device(test.gpu_device_name()):
559        v0_2 = variables.Variable(543.21)
560      save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
561      variables.global_variables_initializer().run()
562
563  def testVariables(self):
564    save_path = os.path.join(self.get_temp_dir(), "variables")
565    with session.Session("", graph=ops_lib.Graph()) as sess:
566      one = variables.Variable(1.0)
567      twos = variables.Variable([2.0, 2.0, 2.0])
568      v2 = saver_test_utils.CheckpointedOp(name="v2")
569      init = variables.global_variables_initializer()
570      save = saver_module.Saver()
571      init.run()
572      v2.insert("k1", 3.0).run()
573      save.save(sess, save_path)
574
575    with session.Session("", graph=ops_lib.Graph()) as sess:
576      one = variables.Variable(0.0)
577      twos = variables.Variable([0.0, 0.0, 0.0])
578      v2 = saver_test_utils.CheckpointedOp(name="v2")
579      # Saver with no arg, defaults to 'all variables'.
580      save = saver_module.Saver()
581      save.restore(sess, save_path)
582      self.assertAllClose(1.0, one.eval())
583      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())
584      self.assertEqual(b"k1", v2.keys().eval())
585      self.assertEqual(3.0, v2.values().eval())
586
587  def testVarListShouldBeEmptyInDeferredBuild(self):
588    with ops_lib.Graph().as_default():
589      v = variables.Variable(1.0)
590      with self.assertRaisesRegexp(ValueError, "defer_build"):
591        saver_module.Saver([v], defer_build=True)
592
593  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
594    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
595    with ops_lib.Graph().as_default(), session.Session() as sess:
596      variables.Variable(1.0)
597      saver = saver_module.Saver(defer_build=True)
598      with self.assertRaisesRegexp(RuntimeError, "build"):
599        saver.save(sess, save_path)
600
601  def testDeferredBuild(self):
602    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
603    with session.Session("", graph=ops_lib.Graph()) as sess:
604      one = variables.Variable(1.0)
605      save = saver_module.Saver(defer_build=True)
606      # if build is not deferred, saver cannot save the `twos`.
607      twos = variables.Variable([2.0, 2.0, 2.0])
608      init = variables.global_variables_initializer()
609      save.build()
610      init.run()
611      save.save(sess, save_path)
612
613    with session.Session("", graph=ops_lib.Graph()) as sess:
614      one = variables.Variable(0.0)
615      twos = variables.Variable([0.0, 0.0, 0.0])
616      # Saver with no arg, defaults to 'all variables'.
617      save = saver_module.Saver()
618      save.restore(sess, save_path)
619      self.assertAllClose(1.0, one.eval())
620      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())
621
622  def testReshape(self):
623    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
624    with session.Session("", graph=ops_lib.Graph()) as sess:
625      var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
626      init = variables.global_variables_initializer()
627      save = saver_module.Saver()
628      init.run()
629      save.save(sess, save_path)
630
631    # Error when restoring with default reshape=False
632    with session.Session("", graph=ops_lib.Graph()) as sess:
633      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
634      save = saver_module.Saver()
635      with self.assertRaisesRegexp(
636          errors_impl.InvalidArgumentError,
637          "Assign requires shapes of both tensors to match."):
638        save.restore(sess, save_path)
639
640    # Restored to new shape with reshape=True
641    with session.Session("", graph=ops_lib.Graph()) as sess:
642      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
643      save = saver_module.Saver(reshape=True)
644      save.restore(sess, save_path)
645      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
646
647  @test_util.run_in_graph_and_eager_modes()
648  def testSaveWithGlobalStep(self, pad_step_number=False):
649    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
650    global_step_int = 5
651    # Save and reload one Variable named "var0".
652    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
653    for use_tensor in [True, False]:
654      with self.test_session(graph=ops_lib.Graph()):
655        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
656        save = saver_module.Saver(
657            {
658                var._shared_name: var
659            }, pad_step_number=pad_step_number)
660        if context.in_graph_mode():
661          self.evaluate(var.initializer)
662          sess = ops_lib.get_default_session()
663        else:
664          sess = None
665        if use_tensor:
666          global_step = constant_op.constant(global_step_int)
667          val = save.save(sess, save_path, global_step=global_step)
668        else:
669          val = save.save(sess, save_path, global_step=global_step_int)
670        if pad_step_number:
671          expected_save_path = "%s-%s" % (save_path,
672                                          "{:08d}".format(global_step_int))
673        else:
674          expected_save_path = "%s-%d" % (save_path, global_step_int)
675        self.assertEqual(expected_save_path, val)
676
677  def testSaveWithGlobalStepWithPadding(self):
678    self.testSaveWithGlobalStep(pad_step_number=True)
679
680  def testSaveToNonexistingPath(self):
681    file_io.write_string_to_file(
682        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
683    paths = [
684        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
685        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
686        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
687    ]
688
689    for save_path in paths:
690      # Build a graph with 2 parameter nodes, and Save and
691      # Restore nodes for them.
692      v0 = variables.Variable(10.0, name="v0")
693      v1 = variables.Variable(20.0, name="v1")
694      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
695      init_all_op = variables.global_variables_initializer()
696
697      # In the case where the parent directory doesn't exist, whether or not the
698      # save succeeds or fails is implementation dependent.  Therefore we allow
699      # both cases.
700      try:
701        with self.test_session() as sess:
702          # Initialize all variables
703          sess.run(init_all_op)
704
705          # Check that the parameter nodes have been initialized.
706          self.assertEqual(10.0, v0.eval())
707          self.assertEqual(20.0, v1.eval())
708
709          # Save the graph.
710          save.save(sess, save_path)
711
712        with self.test_session() as sess:
713          # Restore the saved values in the parameter nodes.
714          save.restore(sess, save_path)
715          # Check that the parameter nodes have been restored.
716          self.assertEqual(10.0, v0.eval())
717          self.assertEqual(20.0, v1.eval())
718      except ValueError as exc:
719        error_msg_template = "Parent directory of {} doesn't exist, can't save."
720        self.assertEqual(error_msg_template.format(save_path), str(exc))
721
722  def testSaveToURI(self):
723    # ParseURI functions don't work on Windows yet.
724    # TODO(jhseu): Remove this check when it works.
725    if os.name == "nt":
726      self.skipTest("Local URI support doesn't work on Windows")
727    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")
728
729    # Build a graph with 2 parameter nodes, and Save and
730    # Restore nodes for them.
731    v0 = variables.Variable(10.0, name="v0")
732    v1 = variables.Variable(20.0, name="v1")
733    save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
734    init_all_op = variables.global_variables_initializer()
735
736    with self.test_session() as sess:
737      # Initialize all variables
738      sess.run(init_all_op)
739
740      # Check that the parameter nodes have been initialized.
741      self.assertEqual(10.0, v0.eval())
742      self.assertEqual(20.0, v1.eval())
743      save.save(sess, save_path)
744
745
746@test_util.with_c_api
747class SaveRestoreShardedTest(test.TestCase):
748
749  _WRITE_VERSION = saver_pb2.SaverDef.V1
750
751  def _get_test_dir(self, dirname):
752    test_dir = os.path.join(self.get_temp_dir(), dirname)
753    gfile.MakeDirs(test_dir)
754    return test_dir
755
756  def testBasics(self):
757    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
758
759    # Build a graph with 2 parameter nodes on different devices.
760    with session.Session(
761        target="",
762        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
763      with sess.graph.device("/cpu:0"):
764        v0 = variables.Variable(10, name="v0")
765        t0 = saver_test_utils.CheckpointedOp(name="t0")
766      with sess.graph.device("/cpu:1"):
767        v1 = variables.Variable(20, name="v1")
768        t1 = saver_test_utils.CheckpointedOp(name="t1")
769      save = saver_module.Saver(
770          {
771              "v0": v0,
772              "v1": v1,
773              "t0": t0.saveable,
774              "t1": t1.saveable
775          },
776          write_version=self._WRITE_VERSION,
777          sharded=True)
778      variables.global_variables_initializer().run()
779      t0.insert("k1", 30.0).run()
780      t1.insert("k2", 40.0).run()
781      val = save.save(sess, save_path)
782      if save._write_version is saver_pb2.SaverDef.V1:
783        self.assertEqual(save_path + "-?????-of-00002", val)
784      else:
785        self.assertEqual(save_path, val)
786      meta_graph_filename = save._MetaGraphFilename(val)
787      self.assertEqual(save_path + ".meta", meta_graph_filename)
788
789    if save._write_version is saver_pb2.SaverDef.V1:
790      # Restore different ops from shard 0 of the saved files.
791      with session.Session(
792          target="",
793          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
794        with sess.graph.device("/cpu:0"):
795          v0 = variables.Variable(111, name="v0")
796          t0 = saver_test_utils.CheckpointedOp(name="t0")
797        save = saver_module.Saver(
798            {
799                "v0": v0,
800                "t0": t0.saveable
801            },
802            write_version=self._WRITE_VERSION,
803            sharded=True)
804        variables.global_variables_initializer().run()
805        t0.insert("k11", 33.0).run()
806        self.assertEqual(111, v0.eval())
807        self.assertEqual(b"k11", t0.keys().eval())
808        self.assertEqual(33.0, t0.values().eval())
809        save.restore(sess, save_path + "-00000-of-00002")
810        self.assertEqual(10, v0.eval())
811        self.assertEqual(b"k1", t0.keys().eval())
812        self.assertEqual(30.0, t0.values().eval())
813
814      # Restore different ops from shard 1 of the saved files.
815      with session.Session(
816          target="",
817          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
818        with sess.graph.device("/cpu:0"):
819          v1 = variables.Variable(222)
820          t1 = saver_test_utils.CheckpointedOp(name="t1")
821        save = saver_module.Saver(
822            {
823                "v1": v1,
824                "t1": t1.saveable
825            },
826            write_version=self._WRITE_VERSION,
827            sharded=True)
828        variables.global_variables_initializer().run()
829        t1.insert("k22", 44.0).run()
830        self.assertEqual(222, v1.eval())
831        self.assertEqual(b"k22", t1.keys().eval())
832        self.assertEqual(44.0, t1.values().eval())
833        save.restore(sess, save_path + "-00001-of-00002")
834        self.assertEqual(20, v1.eval())
835        self.assertEqual(b"k2", t1.keys().eval())
836        self.assertEqual(40.0, t1.values().eval())
837
838    # Now try a restore with the sharded filename.
839    with session.Session(
840        target="",
841        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
842      with sess.graph.device("/cpu:0"):
843        v0 = variables.Variable(111, name="v0")
844        t0 = saver_test_utils.CheckpointedOp(name="t0")
845      with sess.graph.device("/cpu:1"):
846        v1 = variables.Variable(222, name="v1")
847        t1 = saver_test_utils.CheckpointedOp(name="t1")
848      save = saver_module.Saver(
849          {
850              "v0": v0,
851              "v1": v1,
852              "t0": t0.saveable,
853              "t1": t1.saveable
854          },
855          write_version=self._WRITE_VERSION,
856          sharded=True)
857      variables.global_variables_initializer().run()
858      t0.insert("k11", 33.0).run()
859      t1.insert("k22", 44.0).run()
860      self.assertEqual(111, v0.eval())
861      self.assertEqual(222, v1.eval())
862      self.assertEqual(b"k11", t0.keys().eval())
863      self.assertEqual(33.0, t0.values().eval())
864      self.assertEqual(b"k22", t1.keys().eval())
865      self.assertEqual(44.0, t1.values().eval())
866      save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
867      if save._write_version is saver_pb2.SaverDef.V1:
868        save.restore(sess, save_path + "-?????-of-?????")
869      else:
870        save.restore(sess, save_path)
871      self.assertEqual(10, v0.eval())
872      self.assertEqual(20, v1.eval())
873      self.assertEqual(b"k1", t0.keys().eval())
874      self.assertEqual(30.0, t0.values().eval())
875      self.assertEqual(b"k2", t1.keys().eval())
876      self.assertEqual(40.0, t1.values().eval())
877
878    if save._write_version is saver_pb2.SaverDef.V1:
879      self.assertEqual(
880          saver_module.latest_checkpoint(self.get_temp_dir()),
881          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
882    else:
883      self.assertEqual(
884          saver_module.latest_checkpoint(self.get_temp_dir()),
885          os.path.join(self.get_temp_dir(), "sharded_basics"))
886
887  def testSaverDef(self):
888    with self.test_session():
889      v0 = variables.Variable(123, name="v0")
890      save = saver_module.Saver({"v0": v0}, sharded=True)
891      sd = save.as_saver_def()
892      self.assertTrue(sd.sharded)
893
894  def _testPartitionedVariables(self, use_resource):
895    var_full_shape = [10, 3]
896    # Allows save/restore mechanism to work w/ different slicings.
897    var_name = "my_var"
898    saved_dir = self._get_test_dir("partitioned_variables")
899    saved_path = os.path.join(saved_dir, "ckpt")
900
901    call_saver_with_dict = False  # updated by test loop below
902
903    def _save(slices=None, partitioner=None):
904      with self.test_session(graph=ops_lib.Graph()) as sess:
905        # Calls .eval() to return the ndarray that makes up the full variable.
906        rnd = random_ops.random_uniform(var_full_shape).eval()
907
908        if slices:
909          assert not partitioner
910          # TODO(apassos): make create_partitioned_variables take use_resource
911          # option to make this test passable without creating a named
912          # variable_scope.
913          vs = partitioned_variables.create_partitioned_variables(
914              var_full_shape, slices, rnd, name=var_name)
915        elif partitioner:
916          vs = [
917              variable_scope.get_variable(
918                  var_name,
919                  shape=var_full_shape,
920                  initializer=rnd,
921                  partitioner=partitioner,
922                  use_resource=use_resource)
923          ]
924        else:
925          if use_resource:
926            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
927          else:
928            vs = [variables.Variable(rnd, name=var_name)]
929
930        variables.global_variables_initializer().run()
931        if call_saver_with_dict:
932          saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
933        else:
934          saver = saver_module.Saver(vs)
935        actual_path = saver.save(sess, saved_path)
936        self.assertEqual(saved_path, actual_path)
937
938        return rnd
939
940    def _restore(slices=None, partitioner=None):
941      with self.test_session(graph=ops_lib.Graph()) as sess:
942        if slices:
943          assert not partitioner
944          new_vs = partitioned_variables.create_partitioned_variables(
945              var_full_shape,
946              slices,
947              array_ops.zeros(var_full_shape),  # != original contents.
948              name=var_name)
949        elif partitioner:
950          new_vs = [
951              variable_scope.get_variable(
952                  var_name,
953                  shape=var_full_shape,
954                  initializer=array_ops.zeros(var_full_shape),
955                  partitioner=partitioner)
956          ]
957        else:
958          new_vs = [
959              variables.Variable(
960                  array_ops.zeros(
961                      shape=var_full_shape),  # != original contents.
962                  name=var_name)
963          ]
964
965        variables.global_variables_initializer().run()
966        if call_saver_with_dict:
967          saver = saver_module.Saver({
968              var_name: (new_vs if slices else new_vs[0])
969          })
970        else:
971          saver = saver_module.Saver(new_vs)
972        saver.restore(sess, saved_path)
973
974        if partitioner:
975          return new_vs[0].as_tensor().eval()
976        elif slices and slices[0] != 1:
977          return array_ops.concat(new_vs, 0).eval()
978        elif slices and slices[1] != 1:
979          return array_ops.concat(new_vs, 1).eval()
980        else:  # Non-sliced.
981          return new_vs[0].eval()
982
983    for call_saver_with_dict in {False, True}:
984      # Save PartitionedVariable and restore into full variable.
985      saved_full = _save(
986          partitioner=partitioned_variables.fixed_size_partitioner(
987              num_shards=2))
988      restored_full = _restore()
989      self.assertAllEqual(saved_full, restored_full)
990
991      # Saves 10 horizontal parts of a partitioned variable.
992      # Restores into a full variable, non-sliced.
993      saved_full = _save(slices=[10, 1])
994      restored_full = _restore()
995      self.assertAllEqual(saved_full, restored_full)
996
997      # Restores into a different number/orientation of slices.
998      restored_full = _restore(slices=[2, 1])  # 2 horizon parts.
999      self.assertAllEqual(saved_full, restored_full)
1000      restored_full = _restore(slices=[1, 3])  # 3 vertical parts.
1001      self.assertAllEqual(saved_full, restored_full)
1002
1003      # Restores into a PartitionedVariable
1004      restored_full = _restore(
1005          partitioner=partitioned_variables.fixed_size_partitioner(
1006              num_shards=2))
1007      self.assertAllEqual(saved_full, restored_full)
1008
1009      # Now, saves a full variable and restores in slices.
1010      saved_full = _save()
1011      restored_full = _restore(slices=[1, 3])
1012      self.assertAllEqual(saved_full, restored_full)
1013
1014  def testPartitionedVariable(self):
1015    self._testPartitionedVariables(use_resource=False)
1016
1017  def testPartitionedResourceVariable(self):
1018    self._testPartitionedVariables(use_resource=True)
1019
1020
1021@test_util.with_c_api
1022class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
1023  _WRITE_VERSION = saver_pb2.SaverDef.V2
1024
1025
1026@test_util.with_c_api
1027class MaxToKeepTest(test.TestCase):
1028
1029  def _get_test_dir(self, dirname):
1030    test_dir = os.path.join(self.get_temp_dir(), dirname)
1031    gfile.MakeDirs(test_dir)
1032    return test_dir
1033
1034  def assertCheckpointState(self, model_checkpoint_path,
1035                            all_model_checkpoint_paths, save_dir):
1036    checkpoint_state = saver_module.get_checkpoint_state(save_dir)
1037    self.assertEqual(checkpoint_state.model_checkpoint_path,
1038                     model_checkpoint_path)
1039    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
1040                     all_model_checkpoint_paths)
1041
1042  def testNonSharded(self):
1043    save_dir = self._get_test_dir("max_to_keep_non_sharded")
1044
1045    with self.test_session() as sess:
1046      v = variables.Variable(10.0, name="v")
1047      save = saver_module.Saver({"v": v}, max_to_keep=2)
1048      variables.global_variables_initializer().run()
1049      self.assertEqual([], save.last_checkpoints)
1050
1051      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1052      self.assertEqual([s1], save.last_checkpoints)
1053      self.assertTrue(saver_module.checkpoint_exists(s1))
1054      self.assertCheckpointState(
1055          model_checkpoint_path=s1,
1056          all_model_checkpoint_paths=[s1],
1057          save_dir=save_dir)
1058
1059      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1060      self.assertEqual([s1, s2], save.last_checkpoints)
1061      self.assertTrue(saver_module.checkpoint_exists(s1))
1062      self.assertTrue(saver_module.checkpoint_exists(s2))
1063      self.assertCheckpointState(
1064          model_checkpoint_path=s2,
1065          all_model_checkpoint_paths=[s1, s2],
1066          save_dir=save_dir)
1067
1068      s3 = save.save(sess, os.path.join(save_dir, "s3"))
1069      self.assertEqual([s2, s3], save.last_checkpoints)
1070      self.assertFalse(saver_module.checkpoint_exists(s1))
1071      self.assertTrue(saver_module.checkpoint_exists(s2))
1072      self.assertTrue(saver_module.checkpoint_exists(s3))
1073      self.assertCheckpointState(
1074          model_checkpoint_path=s3,
1075          all_model_checkpoint_paths=[s2, s3],
1076          save_dir=save_dir)
1077
1078      # Create a second helper, identical to the first.
1079      save2 = saver_module.Saver(saver_def=save.as_saver_def())
1080      save2.set_last_checkpoints(save.last_checkpoints)
1081
1082      # Create a third helper, with the same configuration but no knowledge of
1083      # previous checkpoints.
1084      save3 = saver_module.Saver(saver_def=save.as_saver_def())
1085
1086      # Exercise the first helper.
1087
1088      # Adding s2 again (old s2 is removed first, then new s2 appended)
1089      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1090      self.assertEqual([s3, s2], save.last_checkpoints)
1091      self.assertFalse(saver_module.checkpoint_exists(s1))
1092      self.assertFalse(
1093          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1094      self.assertTrue(saver_module.checkpoint_exists(s3))
1095      self.assertTrue(
1096          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1097      self.assertTrue(saver_module.checkpoint_exists(s2))
1098      self.assertTrue(
1099          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1100      self.assertCheckpointState(
1101          model_checkpoint_path=s2,
1102          all_model_checkpoint_paths=[s3, s2],
1103          save_dir=save_dir)
1104
1105      # Adding s1 (s3 should now be deleted as oldest in list)
1106      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1107      self.assertEqual([s2, s1], save.last_checkpoints)
1108      self.assertFalse(saver_module.checkpoint_exists(s3))
1109      self.assertFalse(
1110          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1111      self.assertTrue(saver_module.checkpoint_exists(s2))
1112      self.assertTrue(
1113          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1114      self.assertTrue(saver_module.checkpoint_exists(s1))
1115      self.assertTrue(
1116          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1117      self.assertCheckpointState(
1118          model_checkpoint_path=s1,
1119          all_model_checkpoint_paths=[s2, s1],
1120          save_dir=save_dir)
1121
1122      # Exercise the second helper.
1123
1124      # Adding s2 again (old s2 is removed first, then new s2 appended)
1125      s2 = save2.save(sess, os.path.join(save_dir, "s2"))
1126      self.assertEqual([s3, s2], save2.last_checkpoints)
1127      # Created by the first helper.
1128      self.assertTrue(saver_module.checkpoint_exists(s1))
1129      self.assertTrue(
1130          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1131      # Deleted by the first helper.
1132      self.assertFalse(saver_module.checkpoint_exists(s3))
1133      self.assertFalse(
1134          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1135      self.assertTrue(saver_module.checkpoint_exists(s2))
1136      self.assertTrue(
1137          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1138      self.assertCheckpointState(
1139          model_checkpoint_path=s2,
1140          all_model_checkpoint_paths=[s3, s2],
1141          save_dir=save_dir)
1142
1143      # Adding s1 (s3 should now be deleted as oldest in list)
1144      s1 = save2.save(sess, os.path.join(save_dir, "s1"))
1145      self.assertEqual([s2, s1], save2.last_checkpoints)
1146      self.assertFalse(saver_module.checkpoint_exists(s3))
1147      self.assertFalse(
1148          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1149      self.assertTrue(saver_module.checkpoint_exists(s2))
1150      self.assertTrue(
1151          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1152      self.assertTrue(saver_module.checkpoint_exists(s1))
1153      self.assertTrue(
1154          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1155      self.assertCheckpointState(
1156          model_checkpoint_path=s1,
1157          all_model_checkpoint_paths=[s2, s1],
1158          save_dir=save_dir)
1159
1160      # Exercise the third helper.
1161
1162      # Adding s2 again (but helper is unaware of previous s2)
1163      s2 = save3.save(sess, os.path.join(save_dir, "s2"))
1164      self.assertEqual([s2], save3.last_checkpoints)
1165      # Created by the first helper.
1166      self.assertTrue(saver_module.checkpoint_exists(s1))
1167      self.assertTrue(
1168          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1169      # Deleted by the first helper.
1170      self.assertFalse(saver_module.checkpoint_exists(s3))
1171      self.assertFalse(
1172          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1173      self.assertTrue(saver_module.checkpoint_exists(s2))
1174      self.assertTrue(
1175          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1176      # Even though the file for s1 exists, this saver isn't aware of it, which
1177      # is why it doesn't end up in the checkpoint state.
1178      self.assertCheckpointState(
1179          model_checkpoint_path=s2,
1180          all_model_checkpoint_paths=[s2],
1181          save_dir=save_dir)
1182
1183      # Adding s1 (s3 should not be deleted because helper is unaware of it)
1184      s1 = save3.save(sess, os.path.join(save_dir, "s1"))
1185      self.assertEqual([s2, s1], save3.last_checkpoints)
1186      self.assertFalse(saver_module.checkpoint_exists(s3))
1187      self.assertFalse(
1188          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
1189      self.assertTrue(saver_module.checkpoint_exists(s2))
1190      self.assertTrue(
1191          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
1192      self.assertTrue(saver_module.checkpoint_exists(s1))
1193      self.assertTrue(
1194          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
1195      self.assertCheckpointState(
1196          model_checkpoint_path=s1,
1197          all_model_checkpoint_paths=[s2, s1],
1198          save_dir=save_dir)
1199
1200  def testSharded(self):
1201    save_dir = self._get_test_dir("max_to_keep_sharded")
1202
1203    with session.Session(
1204        target="",
1205        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
1206      with sess.graph.device("/cpu:0"):
1207        v0 = variables.Variable(111, name="v0")
1208      with sess.graph.device("/cpu:1"):
1209        v1 = variables.Variable(222, name="v1")
1210      save = saver_module.Saver(
1211          {
1212              "v0": v0,
1213              "v1": v1
1214          }, sharded=True, max_to_keep=2)
1215      variables.global_variables_initializer().run()
1216      self.assertEqual([], save.last_checkpoints)
1217
1218      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1219      self.assertEqual([s1], save.last_checkpoints)
1220      if save._write_version is saver_pb2.SaverDef.V1:
1221        self.assertEqual(2, len(gfile.Glob(s1)))
1222      else:
1223        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
1224
1225      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
1226
1227      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1228      self.assertEqual([s1, s2], save.last_checkpoints)
1229      if save._write_version is saver_pb2.SaverDef.V1:
1230        self.assertEqual(2, len(gfile.Glob(s1)))
1231      else:
1232        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
1233      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
1234      if save._write_version is saver_pb2.SaverDef.V1:
1235        self.assertEqual(2, len(gfile.Glob(s2)))
1236      else:
1237        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
1238      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
1239
1240      s3 = save.save(sess, os.path.join(save_dir, "s3"))
1241      self.assertEqual([s2, s3], save.last_checkpoints)
1242      self.assertEqual(0, len(gfile.Glob(s1 + "*")))
1243      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
1244      if save._write_version is saver_pb2.SaverDef.V1:
1245        self.assertEqual(2, len(gfile.Glob(s2)))
1246      else:
1247        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
1248      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
1249      if save._write_version is saver_pb2.SaverDef.V1:
1250        self.assertEqual(2, len(gfile.Glob(s3)))
1251      else:
1252        self.assertEqual(4, len(gfile.Glob(s3 + "*")))
1253      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
1254
1255  def testNoMaxToKeep(self):
1256    save_dir = self._get_test_dir("no_max_to_keep")
1257    save_dir2 = self._get_test_dir("max_to_keep_0")
1258
1259    with self.test_session() as sess:
1260      v = variables.Variable(10.0, name="v")
1261      variables.global_variables_initializer().run()
1262
1263      # Test max_to_keep being None.
1264      save = saver_module.Saver({"v": v}, max_to_keep=None)
1265      self.assertEqual([], save.last_checkpoints)
1266      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1267      self.assertEqual([], save.last_checkpoints)
1268      self.assertTrue(saver_module.checkpoint_exists(s1))
1269      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1270      self.assertEqual([], save.last_checkpoints)
1271      self.assertTrue(saver_module.checkpoint_exists(s2))
1272
1273      # Test max_to_keep being 0.
1274      save2 = saver_module.Saver({"v": v}, max_to_keep=0)
1275      self.assertEqual([], save2.last_checkpoints)
1276      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
1277      self.assertEqual([], save2.last_checkpoints)
1278      self.assertTrue(saver_module.checkpoint_exists(s1))
1279      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
1280      self.assertEqual([], save2.last_checkpoints)
1281      self.assertTrue(saver_module.checkpoint_exists(s2))
1282
1283  def testNoMetaGraph(self):
1284    save_dir = self._get_test_dir("no_meta_graph")
1285
1286    with self.test_session() as sess:
1287      v = variables.Variable(10.0, name="v")
1288      save = saver_module.Saver({"v": v})
1289      variables.global_variables_initializer().run()
1290
1291      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
1292      self.assertTrue(saver_module.checkpoint_exists(s1))
1293      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
1294
1295
1296@test_util.with_c_api
1297class KeepCheckpointEveryNHoursTest(test.TestCase):
1298
1299  def _get_test_dir(self, dirname):
1300    test_dir = os.path.join(self.get_temp_dir(), dirname)
1301    gfile.MakeDirs(test_dir)
1302    return test_dir
1303
1304  @test.mock.patch.object(saver_module, "time")
1305  def testNonSharded(self, mock_time):
1306    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
1307
1308    with self.test_session() as sess:
1309      v = variables.Variable([10.0], name="v")
1310      # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
1311      # call, which throws the test timing off in fastbuild mode.
1312      variables.global_variables_initializer().run()
1313      # Create a saver that will keep the last 2 checkpoints plus one every 0.7
1314      # seconds.
1315      start_time = time.time()
1316      mock_time.time.return_value = start_time
1317      save = saver_module.Saver(
1318          {
1319              "v": v
1320          }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
1321      self.assertEqual([], save.last_checkpoints)
1322
1323      # Wait till 1 seconds have elapsed so s1 will be old enough to keep.
1324      # sleep may return early, don't trust it.
1325      mock_time.time.return_value = start_time + 1.0
1326      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1327      self.assertEqual([s1], save.last_checkpoints)
1328
1329      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1330      self.assertEqual([s1, s2], save.last_checkpoints)
1331
1332      # We now have 2 'last_checkpoints': [s1, s2].  The next call to Save(),
1333      # would normally delete s1, because max_to_keep is 2.  However, s1 is
1334      # older than 0.7s so we must keep it.
1335      s3 = save.save(sess, os.path.join(save_dir, "s3"))
1336      self.assertEqual([s2, s3], save.last_checkpoints)
1337
1338      # s1 should still be here, we are Not checking now to reduce time
1339      # variance in the test.
1340
1341      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk.  The next
1342      # call to Save(), will delete s2, because max_to_keep is 2, and because
1343      # we already kept the old s1. s2 is very close in time to s1 so it gets
1344      # deleted.
1345      s4 = save.save(sess, os.path.join(save_dir, "s4"))
1346      self.assertEqual([s3, s4], save.last_checkpoints)
1347
1348      # Check that s1 is still here, but s2 is gone.
1349      self.assertTrue(saver_module.checkpoint_exists(s1))
1350      self.assertFalse(saver_module.checkpoint_exists(s2))
1351      self.assertTrue(saver_module.checkpoint_exists(s3))
1352      self.assertTrue(saver_module.checkpoint_exists(s4))
1353
1354
1355@test_util.with_c_api
1356class SaveRestoreWithVariableNameMap(test.TestCase):
1357
1358  def _testNonReshape(self, variable_op):
1359    save_path = os.path.join(self.get_temp_dir(), "non_reshape")
1360
1361    with self.test_session(graph=ops_lib.Graph()) as sess:
1362      # Build a graph with 2 parameter nodes, and Save and
1363      # Restore nodes for them.
1364      v0 = variable_op(10.0, name="v0")
1365      v1 = variable_op(20.0, name="v1")
1366      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
1367      self.evaluate(variables.global_variables_initializer())
1368
1369      # Check that the parameter nodes have been initialized.
1370      self.assertEqual(10.0, self.evaluate(v0))
1371      self.assertEqual(20.0, self.evaluate(v1))
1372
1373      # Save the initialized values in the file at "save_path"
1374      # Use a variable name map to set the saved tensor names
1375      val = save.save(sess, save_path)
1376      self.assertTrue(isinstance(val, six.string_types))
1377      self.assertEqual(save_path, val)
1378
1379      # Verify that the original names are not in the Saved file
1380      save = saver_module.Saver({"v0": v0, "v1": v1})
1381      with self.assertRaisesOpError("not found in checkpoint"):
1382        save.restore(sess, save_path)
1383
1384    # Verify that the mapped names are present in the Saved file and can be
1385    # Restored using remapped names.
1386    with self.test_session(graph=ops_lib.Graph()) as sess:
1387      v0 = variable_op(-1.0, name="v0")
1388      v1 = variable_op(-1.0, name="v1")
1389
1390      if context.in_graph_mode():
1391        with self.assertRaisesOpError("uninitialized"):
1392          self.evaluate(v0)
1393        with self.assertRaisesOpError("uninitialized"):
1394          self.evaluate(v1)
1395
1396      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
1397      save.restore(sess, save_path)
1398
1399      # Check that the parameter nodes have been restored.
1400      if context.in_graph_mode():
1401        self.assertEqual(10.0, self.evaluate(v0))
1402        self.assertEqual(20.0, self.evaluate(v1))
1403
1404    # Add a prefix to the node names in the current graph and Restore using
1405    # remapped names.
1406    with self.test_session(graph=ops_lib.Graph()) as sess:
1407      v0 = variable_op(-1.0, name="restore_prefix/v0")
1408      v1 = variable_op(-1.0, name="restore_prefix/v1")
1409
1410      if context.in_graph_mode():
1411        with self.assertRaisesOpError("uninitialized"):
1412          self.evaluate(v0)
1413        with self.assertRaisesOpError("uninitialized"):
1414          self.evaluate(v1)
1415
1416      # Restore the saved values in the parameter nodes.
1417      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
1418      save.restore(sess, save_path)
1419
1420      # Check that the parameter nodes have been restored.
1421      self.assertEqual(10.0, self.evaluate(v0))
1422      self.assertEqual(20.0, self.evaluate(v1))
1423
1424  @test_util.run_in_graph_and_eager_modes()
1425  def testNonReshapeResourceVariable(self):
1426    self._testNonReshape(resource_variable_ops.ResourceVariable)
1427
1428  def testNonReshapeVariable(self):
1429    self._testNonReshape(variables.Variable)
1430
1431
1432@test_util.with_c_api
1433class LatestCheckpointWithRelativePaths(test.TestCase):
1434
1435  @staticmethod
1436  @contextlib.contextmanager
1437  def tempWorkingDir(temppath):
1438    cwd = os.getcwd()
1439    os.chdir(temppath)
1440    try:
1441      yield
1442    finally:
1443      os.chdir(cwd)
1444
1445  @staticmethod
1446  @contextlib.contextmanager
1447  def tempDir():
1448    tempdir = tempfile.mkdtemp()
1449    try:
1450      yield tempdir
1451    finally:
1452      shutil.rmtree(tempdir)
1453
1454  def testNameCollision(self):
1455    # Make sure we have a clean directory to work in.
1456    with self.tempDir() as tempdir:
1457      # Jump to that directory until this test is done.
1458      with self.tempWorkingDir(tempdir):
1459        # Save training snapshots to a relative path.
1460        traindir = "train/"
1461        os.mkdir(traindir)
1462        # Collides with the default name of the checkpoint state file.
1463        filepath = os.path.join(traindir, "checkpoint")
1464
1465        with self.test_session() as sess:
1466          unused_a = variables.Variable(0.0)  # So that Saver saves something.
1467          variables.global_variables_initializer().run()
1468
1469          # Should fail.
1470          saver = saver_module.Saver(sharded=False)
1471          with self.assertRaisesRegexp(ValueError, "collides with"):
1472            saver.save(sess, filepath)
1473
1474          # Succeeds: the file will be named "checkpoint-<step>".
1475          saver.save(sess, filepath, global_step=1)
1476          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
1477
1478          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
1479          saver = saver_module.Saver(sharded=True)
1480          saver.save(sess, filepath)
1481          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
1482
1483          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
1484          saver = saver_module.Saver(sharded=True)
1485          saver.save(sess, filepath, global_step=1)
1486          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
1487
1488  def testRelativePath(self):
1489    # Make sure we have a clean directory to work in.
1490    with self.tempDir() as tempdir:
1491
1492      # Jump to that directory until this test is done.
1493      with self.tempWorkingDir(tempdir):
1494
1495        # Save training snapshots to a relative path.
1496        traindir = "train/"
1497        os.mkdir(traindir)
1498
1499        filename = "snapshot"
1500        filepath = os.path.join(traindir, filename)
1501
1502        with self.test_session() as sess:
1503          # Build a simple graph.
1504          v0 = variables.Variable(0.0)
1505          inc = v0.assign_add(1.0)
1506
1507          save = saver_module.Saver({"v0": v0})
1508
1509          # Record a short training history.
1510          variables.global_variables_initializer().run()
1511          save.save(sess, filepath, global_step=0)
1512          inc.eval()
1513          save.save(sess, filepath, global_step=1)
1514          inc.eval()
1515          save.save(sess, filepath, global_step=2)
1516
1517        with self.test_session() as sess:
1518          # Build a new graph with different initialization.
1519          v0 = variables.Variable(-1.0)
1520
1521          # Create a new saver.
1522          save = saver_module.Saver({"v0": v0})
1523          variables.global_variables_initializer().run()
1524
1525          # Get the most recent checkpoint name from the training history file.
1526          name = saver_module.latest_checkpoint(traindir)
1527          self.assertIsNotNone(name)
1528
1529          # Restore "v0" from that checkpoint.
1530          save.restore(sess, name)
1531          self.assertEqual(v0.eval(), 2.0)
1532
1533
1534@test_util.with_c_api
1535class CheckpointStateTest(test.TestCase):
1536
1537  def _get_test_dir(self, dirname):
1538    test_dir = os.path.join(self.get_temp_dir(), dirname)
1539    gfile.MakeDirs(test_dir)
1540    return test_dir
1541
1542  def testAbsPath(self):
1543    save_dir = self._get_test_dir("abs_paths")
1544    abs_path = os.path.join(save_dir, "model-0")
1545    ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
1546    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
1547    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
1548    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
1549    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
1550
1551  def testRelPath(self):
1552    train_dir = "train"
1553    model = os.path.join(train_dir, "model-0")
1554    # model_checkpoint_path should have no "train" directory part.
1555    new_rel_path = "model-0"
1556    ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
1557    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
1558    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
1559    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
1560
1561  def testAllModelCheckpointPaths(self):
1562    save_dir = self._get_test_dir("all_models_test")
1563    abs_path = os.path.join(save_dir, "model-0")
1564    for paths in [None, [], ["model-2"]]:
1565      ckpt = saver_module.generate_checkpoint_state_proto(
1566          save_dir, abs_path, all_model_checkpoint_paths=paths)
1567      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
1568      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
1569      self.assertEqual(
1570          len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
1571      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
1572
1573  def testUpdateCheckpointState(self):
1574    save_dir = self._get_test_dir("update_checkpoint_state")
1575    os.chdir(save_dir)
1576    # Make a temporary train directory.
1577    train_dir = "train"
1578    os.mkdir(train_dir)
1579    abs_path = os.path.join(save_dir, "model-0")
1580    rel_path = os.path.join("train", "model-2")
1581    saver_module.update_checkpoint_state(
1582        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
1583    ckpt = saver_module.get_checkpoint_state(train_dir)
1584    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
1585    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
1586    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
1587    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
1588
1589  def testUpdateCheckpointStateSaveRelativePaths(self):
1590    save_dir = self._get_test_dir("update_checkpoint_state")
1591    os.chdir(save_dir)
1592    abs_path2 = os.path.join(save_dir, "model-2")
1593    rel_path2 = "model-2"
1594    abs_path0 = os.path.join(save_dir, "model-0")
1595    rel_path0 = "model-0"
1596    saver_module._update_checkpoint_state(  # pylint: disable=protected-access
1597        save_dir=save_dir,
1598        model_checkpoint_path=abs_path2,
1599        all_model_checkpoint_paths=[rel_path0, abs_path2],
1600        save_relative_paths=True)
1601
1602    # File should contain relative paths.
1603    file_content = file_io.read_file_to_string(
1604        os.path.join(save_dir, "checkpoint"))
1605    ckpt = CheckpointState()
1606    text_format.Merge(file_content, ckpt)
1607    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
1608    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
1609    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
1610    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
1611
1612    # get_checkpoint_state should return absolute paths.
1613    ckpt = saver_module.get_checkpoint_state(save_dir)
1614    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
1615    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
1616    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
1617    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
1618
1619  def testCheckPointStateFailsWhenIncomplete(self):
1620    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
1621    os.chdir(save_dir)
1622    ckpt_path = os.path.join(save_dir, "checkpoint")
1623    ckpt_file = open(ckpt_path, "w")
1624    ckpt_file.write("")
1625    ckpt_file.close()
1626    with self.assertRaises(ValueError):
1627      saver_module.get_checkpoint_state(save_dir)
1628
1629  def testCheckPointCompletesRelativePaths(self):
1630    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
1631    os.chdir(save_dir)
1632    ckpt_path = os.path.join(save_dir, "checkpoint")
1633    ckpt_file = open(ckpt_path, "w")
1634    ckpt_file.write("""
1635        model_checkpoint_path: "./model.ckpt-687529"
1636        all_model_checkpoint_paths: "./model.ckpt-687500"
1637        all_model_checkpoint_paths: "./model.ckpt-687529"
1638        """)
1639    ckpt_file.close()
1640    ckpt = saver_module.get_checkpoint_state(save_dir)
1641    self.assertEqual(ckpt.model_checkpoint_path,
1642                     os.path.join(save_dir, "./model.ckpt-687529"))
1643    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
1644                     os.path.join(save_dir, "./model.ckpt-687500"))
1645    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
1646                     os.path.join(save_dir, "./model.ckpt-687529"))
1647
1648
1649@test_util.with_c_api
1650class MetaGraphTest(test.TestCase):
1651
1652  def _get_test_dir(self, dirname):
1653    test_dir = os.path.join(self.get_temp_dir(), dirname)
1654    gfile.MakeDirs(test_dir)
1655    return test_dir
1656
1657  def testAddCollectionDef(self):
1658    test_dir = self._get_test_dir("good_collection")
1659    filename = os.path.join(test_dir, "metafile")
1660    with self.test_session():
1661      # Creates a graph.
1662      v0 = variables.Variable(1.0, name="v0")
1663      control_flow_ops.cond(
1664          math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
1665          lambda: math_ops.subtract(v0, 1))
1666      control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
1667                                  lambda i: math_ops.add(i, 1), [v0])
1668      var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))
1669      count_up_to = var.count_up_to(3)
1670      input_queue = data_flow_ops.FIFOQueue(
1671          30, dtypes.float32, shared_name="collection_queue")
1672      qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])
1673      variables.global_variables_initializer()
1674      # Creates a saver.
1675      save = saver_module.Saver({"v0": v0})
1676      # Adds a set of collections.
1677      ops_lib.add_to_collection("int_collection", 3)
1678      ops_lib.add_to_collection("float_collection", 3.5)
1679      ops_lib.add_to_collection("string_collection", "hello")
1680      ops_lib.add_to_collection("variable_collection", v0)
1681      # Add QueueRunners.
1682      queue_runner_impl.add_queue_runner(qr)
1683      # Adds user_defined proto in three formats: string, bytes and Any.
1684      queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
1685      ops_lib.add_to_collection("user_defined_string_collection",
1686                                str(queue_runner))
1687      ops_lib.add_to_collection("user_defined_bytes_collection",
1688                                queue_runner.SerializeToString())
1689      any_buf = Any()
1690      any_buf.Pack(queue_runner)
1691      ops_lib.add_to_collection("user_defined_any_collection", any_buf)
1692
1693      # Generates MetaGraphDef.
1694      meta_graph_def = save.export_meta_graph(filename)
1695      self.assertTrue(meta_graph_def.HasField("saver_def"))
1696      self.assertTrue(meta_graph_def.HasField("graph_def"))
1697      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
1698      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
1699      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
1700                          "")
1701      collection_def = meta_graph_def.collection_def
1702      self.assertEqual(len(collection_def), 12)
1703
1704    with ops_lib.Graph().as_default():
1705      # Restores from MetaGraphDef.
1706      new_saver = saver_module.import_meta_graph(filename)
1707      # Generates a new MetaGraphDef.
1708      new_meta_graph_def = new_saver.export_meta_graph()
1709      # It should be the same as the original.
1710
1711    test_util.assert_meta_graph_protos_equal(
1712        self, meta_graph_def, new_meta_graph_def)
1713
1714  def testAddCollectionDefFails(self):
1715    with self.test_session():
1716      # Creates a graph.
1717      v0 = variables.Variable(10.0, name="v0")
1718      # Creates a saver.
1719      save = saver_module.Saver({"v0": v0})
1720      # Generates MetaGraphDef.
1721      meta_graph_def = meta_graph_pb2.MetaGraphDef()
1722
1723      # Verifies that collection with unsupported key will not be added.
1724      ops_lib.add_to_collection(save, 3)
1725      save._add_collection_def(meta_graph_def, save)
1726      self.assertEqual(len(meta_graph_def.collection_def), 0)
1727
1728      # Verifies that collection where item type does not match expected
1729      # type will not be added.
1730      ops_lib.add_to_collection("int_collection", 3)
1731      ops_lib.add_to_collection("int_collection", 3.5)
1732      save._add_collection_def(meta_graph_def, "int_collection")
1733      self.assertEqual(len(meta_graph_def.collection_def), 0)
1734
1735  def _testMultiSaverCollectionSave(self, test_dir):
1736    filename = os.path.join(test_dir, "metafile")
1737    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1738    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
1739    with self.test_session(graph=ops_lib.Graph()) as sess:
1740      # Creates a graph.
1741      v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
1742      v1 = variables.Variable(11.0, name="v1")
1743      # Creates 2 savers.
1744      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
1745      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
1746      ops_lib.add_to_collection("savers", saver0)
1747      ops_lib.add_to_collection("savers", saver1)
1748      variables.global_variables_initializer().run()
1749      # Saves to different checkpoints.
1750      saver0.save(sess, saver0_ckpt)
1751      saver1.save(sess, saver1_ckpt)
1752      # Generates MetaGraphDef.
1753      meta_graph_def = saver_module.export_meta_graph(filename)
1754      meta_graph_def0 = saver0.export_meta_graph()
1755      meta_graph_def1 = saver1.export_meta_graph()
1756
1757      # Verifies that there is no saver_def in meta_graph_def.
1758      self.assertFalse(meta_graph_def.HasField("saver_def"))
1759      # Verifies that there is saver_def in meta_graph_def0 and 1.
1760      self.assertTrue(meta_graph_def0.HasField("saver_def"))
1761      self.assertTrue(meta_graph_def1.HasField("saver_def"))
1762
1763      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
1764      collection_def = meta_graph_def.collection_def["savers"]
1765      kind = collection_def.WhichOneof("kind")
1766      self.assertEqual(kind, "bytes_list")
1767      # Verifies that there are 2 entries in SAVERS collection.
1768      savers = getattr(collection_def, kind)
1769      self.assertEqual(2, len(savers.value))
1770
1771      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.
1772      collection_def = meta_graph_def0.collection_def["savers"]
1773      kind = collection_def.WhichOneof("kind")
1774      self.assertEqual(kind, "bytes_list")
1775      # Verifies that there are 2 entries in SAVERS collection.
1776      savers = getattr(collection_def, kind)
1777      self.assertEqual(2, len(savers.value))
1778
1779  def _testMultiSaverCollectionRestore(self, test_dir):
1780    filename = os.path.join(test_dir, "metafile")
1781    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1782    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
1783    with self.test_session(graph=ops_lib.Graph()) as sess:
1784      # Imports from meta_graph.
1785      saver_module.import_meta_graph(filename)
1786      # Retrieves SAVERS collection. Verifies there are 2 entries.
1787      savers = ops_lib.get_collection("savers")
1788      self.assertEqual(2, len(savers))
1789      # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1.
1790      new_saver0 = savers[0]
1791      new_saver0.restore(sess, saver0_ckpt)
1792      v0 = sess.graph.get_tensor_by_name("v0:0")
1793      v1 = sess.graph.get_tensor_by_name("v1:0")
1794      self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], v0.eval())
1795      self.assertEqual([3, 2], v0.get_shape())
1796      self.assertEqual([], v1.get_shape())
1797      with self.assertRaisesWithPredicateMatch(
1798          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
1799        sess.run(v1)
1800      # Retrieves saver1. Verifies that new_saver1 can restore v1.
1801      new_saver1 = savers[1]
1802      new_saver1.restore(sess, saver1_ckpt)
1803      v1 = sess.graph.get_tensor_by_name("v1:0")
1804      self.assertEqual(11.0, v1.eval())
1805
1806  def testMultiSaverCollection(self):
1807    test_dir = self._get_test_dir("saver_collection")
1808    self._testMultiSaverCollectionSave(test_dir)
1809    self._testMultiSaverCollectionRestore(test_dir)
1810
1811  def testClearExtraneousSavers(self):
1812    test_dir = self._get_test_dir("clear_extraneous_savers")
1813    filename = os.path.join(test_dir, "metafile")
1814    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1815    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
1816    with self.test_session(graph=ops_lib.Graph()) as sess:
1817      # Creates a graph.
1818      v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
1819      v1 = variables.Variable(11.0, name="v1")
1820
1821      # Creates 2 savers.
1822      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
1823      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
1824      ops_lib.add_to_collection("savers", saver0)
1825      ops_lib.add_to_collection("savers", saver1)
1826      variables.global_variables_initializer().run()
1827
1828      # Saves to different checkpoints.
1829      saver0.save(sess, saver0_ckpt)
1830      saver1.save(sess, saver1_ckpt)
1831
1832      # Generates MetaGraphDef.
1833      meta_graph_def = saver_module.export_meta_graph(filename)
1834      meta_graph_def0 = saver0.export_meta_graph()
1835      meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)
1836
1837      # Verifies that there is no saver_def in meta_graph_def.
1838      self.assertFalse(meta_graph_def.HasField("saver_def"))
1839      # Verifies that there is saver_def in meta_graph_def0 and 1.
1840      self.assertTrue(meta_graph_def0.HasField("saver_def"))
1841      self.assertTrue(meta_graph_def1.HasField("saver_def"))
1842
1843      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
1844      collection_def = meta_graph_def.collection_def["savers"]
1845      kind = collection_def.WhichOneof("kind")
1846      self.assertEqual(kind, "bytes_list")
1847
1848      # Verifies that there are 2 entries in SAVERS collection.
1849      savers = getattr(collection_def, kind)
1850      self.assertEqual(2, len(savers.value))
1851
1852      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.
1853      collection_def = meta_graph_def1.collection_def["savers"]
1854      kind = collection_def.WhichOneof("kind")
1855      self.assertEqual(kind, "bytes_list")
1856
1857      # Verifies that there is 1 entry in SAVERS collection.
1858      savers = getattr(collection_def, kind)
1859      self.assertEqual(1, len(savers.value))
1860
1861      # Verifies that saver0 graph nodes are omitted from the saver1 export
1862      self.assertEqual(29, len(meta_graph_def0.graph_def.node))
1863      self.assertEqual(19, len(meta_graph_def1.graph_def.node))
1864
1865  def testBinaryAndTextFormat(self):
1866    test_dir = self._get_test_dir("binary_and_text")
1867    filename = os.path.join(test_dir, "metafile")
1868    with self.test_session(graph=ops_lib.Graph()):
1869      # Creates a graph.
1870      variables.Variable(10.0, name="v0")
1871      # Exports the graph as binary format.
1872      saver_module.export_meta_graph(filename, as_text=False)
1873    with self.test_session(graph=ops_lib.Graph()):
1874      # Imports the binary format graph.
1875      saver = saver_module.import_meta_graph(filename)
1876      self.assertIsNotNone(saver)
1877      # Exports the graph as text format.
1878      saver.export_meta_graph(filename, as_text=True)
1879    with self.test_session(graph=ops_lib.Graph()):
1880      # Imports the text format graph.
1881      saver_module.import_meta_graph(filename)
1882      # Writes wrong contents to the file.
1883      graph_io.write_graph(saver.as_saver_def(),
1884                           os.path.dirname(filename),
1885                           os.path.basename(filename))
1886    with self.test_session(graph=ops_lib.Graph()):
1887      # Import should fail.
1888      with self.assertRaisesWithPredicateMatch(IOError,
1889                                               lambda e: "Cannot parse file"):
1890        saver_module.import_meta_graph(filename)
1891      # Deletes the file
1892      gfile.Remove(filename)
1893      with self.assertRaisesWithPredicateMatch(IOError,
1894                                               lambda e: "does not exist"):
1895        saver_module.import_meta_graph(filename)
1896
1897  def testSliceVariable(self):
1898    test_dir = self._get_test_dir("slice_saver")
1899    filename = os.path.join(test_dir, "metafile")
1900    with self.test_session():
1901      v1 = variables.Variable([20.0], name="v1")
1902      v2 = variables.Variable([20.0], name="v2")
1903      v2._set_save_slice_info(
1904          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
1905
1906      # The names are different and will work.
1907      slice_saver = saver_module.Saver({"first": v1, "second": v2})
1908      variables.global_variables_initializer().run()
1909      # Exports to meta_graph
1910      meta_graph_def = slice_saver.export_meta_graph(filename)
1911
1912    with ops_lib.Graph().as_default():
1913      # Restores from MetaGraphDef.
1914      new_saver = saver_module.import_meta_graph(filename)
1915      self.assertIsNotNone(new_saver)
1916      # Generates a new MetaGraphDef.
1917      new_meta_graph_def = new_saver.export_meta_graph()
1918      # It should be the same as the original.
1919      test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
1920                                               new_meta_graph_def)
1921
1922  def _testGraphExtensionSave(self, test_dir):
1923    filename = os.path.join(test_dir, "metafile")
1924    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1925    # Creates an inference graph.
1926    # Hidden 1
1927    images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
1928    with ops_lib.name_scope("hidden1"):
1929      weights = variables.Variable(
1930          random_ops.truncated_normal(
1931              [28, 128], stddev=1.0 / math.sqrt(float(28))),
1932          name="weights")
1933      # The use of control_flow_ops.cond here is purely for adding test coverage
1934      # the save and restore of control flow context (which doesn't make any
1935      # sense here from a machine learning perspective).  The typical biases is
1936      # a simple Variable without the conditions.
1937      biases = variables.Variable(
1938          control_flow_ops.cond(
1939              math_ops.less(random.random(), 0.5),
1940              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
1941          name="biases")
1942      hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
1943    # Hidden 2
1944    with ops_lib.name_scope("hidden2"):
1945      weights = variables.Variable(
1946          random_ops.truncated_normal(
1947              [128, 32], stddev=1.0 / math.sqrt(float(128))),
1948          name="weights")
1949
1950      # The use of control_flow_ops.while_loop here is purely for adding test
1951      # coverage the save and restore of control flow context (which doesn't
1952      # make any sense here from a machine learning perspective).  The typical
1953      # biases is a simple Variable without the conditions.
1954      def loop_cond(it, _):
1955        return it < 2
1956
1957      def loop_body(it, biases):
1958        biases += constant_op.constant(0.1, shape=[32])
1959        return it + 1, biases
1960
1961      _, biases = control_flow_ops.while_loop(
1962          loop_cond, loop_body,
1963          [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])
1964      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
1965    # Linear
1966    with ops_lib.name_scope("softmax_linear"):
1967      weights = variables.Variable(
1968          random_ops.truncated_normal(
1969              [32, 10], stddev=1.0 / math.sqrt(float(32))),
1970          name="weights")
1971      biases = variables.Variable(array_ops.zeros([10]), name="biases")
1972      logits = math_ops.matmul(hidden2, weights) + biases
1973      ops_lib.add_to_collection("logits", logits)
1974    init_all_op = variables.global_variables_initializer()
1975
1976    with self.test_session() as sess:
1977      # Initializes all the variables.
1978      sess.run(init_all_op)
1979      # Runs to logit.
1980      sess.run(logits)
1981      # Creates a saver.
1982      saver0 = saver_module.Saver()
1983      saver0.save(sess, saver0_ckpt)
1984      # Generates MetaGraphDef.
1985      saver0.export_meta_graph(filename)
1986
1987  def _testGraphExtensionRestore(self, test_dir):
1988    filename = os.path.join(test_dir, "metafile")
1989    train_filename = os.path.join(test_dir, "train_metafile")
1990    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1991    with self.test_session(graph=ops_lib.Graph()) as sess:
1992      # Restores from MetaGraphDef.
1993      new_saver = saver_module.import_meta_graph(filename)
1994      # Generates a new MetaGraphDef.
1995      new_saver.export_meta_graph()
1996      # Restores from checkpoint.
1997      new_saver.restore(sess, saver0_ckpt)
1998      # Adds loss and train.
1999      labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")
2000      batch_size = array_ops.size(labels)
2001      labels = array_ops.expand_dims(labels, 1)
2002      indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)
2003      concated = array_ops.concat([indices, labels], 1)
2004      onehot_labels = sparse_ops.sparse_to_dense(
2005          concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
2006      logits = ops_lib.get_collection("logits")[0]
2007      cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
2008          labels=onehot_labels, logits=logits, name="xentropy")
2009      loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")
2010
2011      summary.scalar("loss", loss)
2012      # Creates the gradient descent optimizer with the given learning rate.
2013      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
2014
2015      # Runs train_op.
2016      train_op = optimizer.minimize(loss)
2017      ops_lib.add_to_collection("train_op", train_op)
2018
2019      # Runs train_op.
2020      sess.run(train_op)
2021
2022      # Generates MetaGraphDef.
2023      saver_module.export_meta_graph(train_filename)
2024
2025  def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
2026    train_filename = os.path.join(test_dir, "train_metafile")
2027    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
2028    with self.test_session(graph=ops_lib.Graph()) as sess:
2029      # Restores from MetaGraphDef.
2030      new_saver = saver_module.import_meta_graph(train_filename)
2031      # Restores from checkpoint.
2032      new_saver.restore(sess, saver0_ckpt)
2033      train_op = ops_lib.get_collection("train_op")[0]
2034      sess.run(train_op)
2035
2036  def testGraphExtension(self):
2037    test_dir = self._get_test_dir("graph_extension")
2038    self._testGraphExtensionSave(test_dir)
2039    self._testGraphExtensionRestore(test_dir)
2040    self._testRestoreFromTrainGraphWithControlContext(test_dir)
2041
2042  def testStrippedOpListDef(self):
2043    with self.test_session():
2044      # Creates a graph.
2045      v0 = variables.Variable(0.0)
2046      var = variables.Variable(10.0)
2047      math_ops.add(v0, var)
2048
2049      @function.Defun(dtypes.float32)
2050      def minus_one(x):
2051        return x - 1
2052
2053      minus_one(array_ops.identity(v0))
2054      save = saver_module.Saver({"v0": v0})
2055      variables.global_variables_initializer()
2056
2057      # Generates MetaGraphDef.
2058      meta_graph_def = save.export_meta_graph()
2059      ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
2060      if save._write_version is saver_pb2.SaverDef.V1:
2061        self.assertEqual(ops, [
2062            "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2",
2063            "SaveSlices", "Sub", "VariableV2"
2064        ])
2065      else:
2066        self.assertEqual(ops, [
2067            "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2",
2068            "Sub", "VariableV2"
2069        ])
2070
2071      # Test calling stripped_op_list_for_graph directly
2072      op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)
2073      self.assertEqual(ops, [o.name for o in op_list.op])
2074      for o in op_list.op:
2075        self.assertEqual(o.summary, "")
2076        self.assertEqual(o.description, "")
2077
2078  def testStripDefaultValuedAttrs(self):
2079    """Verifies that default valued attrs are stripped, unless disabled."""
2080
2081    # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
2082    # (complex64) in the "Complex" op must be removed.
2083    with self.test_session():
2084      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
2085      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
2086      math_ops.complex(real_num, imag_num, name="complex")
2087
2088      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
2089      variables.global_variables_initializer()
2090
2091      meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
2092      node_def = test_util.get_node_def_from_graph("complex",
2093                                                   meta_graph_def.graph_def)
2094      self.assertNotIn("T", node_def.attr)
2095      self.assertNotIn("Tout", node_def.attr)
2096
2097    # With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
2098    # (complex64) in the "Complex" op must *not* be removed, even if they map
2099    # to their defaults.
2100    with self.test_session(graph=ops_lib.Graph()):
2101      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
2102      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
2103      math_ops.complex(real_num, imag_num, name="complex")
2104
2105      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
2106      variables.global_variables_initializer()
2107
2108      meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
2109      node_def = test_util.get_node_def_from_graph("complex",
2110                                                   meta_graph_def.graph_def)
2111      self.assertIn("T", node_def.attr)
2112      self.assertIn("Tout", node_def.attr)
2113
2114  def testImportIntoNamescope(self):
2115    # Test that we can import a meta graph into a namescope.
2116    test_dir = self._get_test_dir("import_into_namescope")
2117    filename = os.path.join(test_dir, "ckpt")
2118    image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2119    label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2120    with session.Session() as sess:
2121      weights = variables.Variable(
2122          random_ops.random_uniform([784, 10]), name="weights")
2123      bias = variables.Variable(array_ops.zeros([10]), name="bias")
2124      logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
2125      nn_ops.softmax(logit, name="prediction")
2126      cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
2127                                                      logits=logit, name="cost")
2128      adam.AdamOptimizer().minimize(cost, name="optimize")
2129      saver = saver_module.Saver()
2130      sess.run(variables.global_variables_initializer())
2131      saver.save(sess, filename)
2132
2133    graph = ops_lib.Graph()
2134    with session.Session(graph=graph) as sess:
2135      new_saver = saver_module.import_meta_graph(
2136          filename + ".meta", graph=graph, import_scope="new_model")
2137      new_saver.restore(sess, filename)
2138      sess.run(["new_model/optimize"], {
2139          "new_model/image:0": np.random.random([1, 784]),
2140          "new_model/label:0": np.random.randint(
2141              10, size=[1, 10])
2142      })
2143
2144  def testClearDevicesOnImport(self):
2145    # Test that we import a graph without its devices and run successfully.
2146    with ops_lib.Graph().as_default():
2147      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
2148        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2149        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2150        weights = variables.Variable(
2151            random_ops.random_uniform([784, 10]), name="weights")
2152        bias = variables.Variable(array_ops.zeros([10]), name="bias")
2153        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
2154        nn_ops.softmax(logit, name="prediction")
2155        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
2156                                                        logits=logit)
2157        adam.AdamOptimizer().minimize(cost, name="optimize")
2158      meta_graph_def = saver_module.export_meta_graph()
2159
2160    with session.Session(graph=ops_lib.Graph()) as sess:
2161      saver_module.import_meta_graph(
2162          meta_graph_def, clear_devices=False, import_scope="new_model")
2163      # Device refers to GPU, which is not available here.
2164      with self.assertRaises(errors_impl.InvalidArgumentError):
2165        sess.run(variables.global_variables_initializer())
2166
2167    with session.Session(graph=ops_lib.Graph()) as sess:
2168      saver_module.import_meta_graph(
2169          meta_graph_def, clear_devices=True, import_scope="new_model")
2170      sess.run(variables.global_variables_initializer())
2171      sess.run(["new_model/optimize"], {
2172          "new_model/image:0": np.random.random([1, 784]),
2173          "new_model/label:0": np.random.randint(
2174              10, size=[1, 10])
2175      })
2176
2177  def testClearDevicesOnExport(self):
2178    # Test that we export a graph without its devices and run successfully.
2179    with ops_lib.Graph().as_default():
2180      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
2181        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
2182        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
2183        weights = variables.Variable(
2184            random_ops.random_uniform([784, 10]), name="weights")
2185        bias = variables.Variable(array_ops.zeros([10]), name="bias")
2186        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
2187        nn_ops.softmax(logit, name="prediction")
2188        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
2189                                                        logits=logit)
2190        adam.AdamOptimizer().minimize(cost, name="optimize")
2191      meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
2192      graph_io.write_graph(meta_graph_def, self.get_temp_dir(),
2193                           "meta_graph.pbtxt")
2194
2195    with session.Session(graph=ops_lib.Graph()) as sess:
2196      saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
2197      sess.run(variables.global_variables_initializer())
2198      sess.run(["new_model/optimize"], {
2199          "new_model/image:0": np.random.random([1, 784]),
2200          "new_model/label:0": np.random.randint(
2201              10, size=[1, 10])
2202      })
2203
2204  def testPreserveDatasetAndFunctions(self):
2205    with ops_lib.Graph().as_default() as g:
2206      dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
2207      iterator = dataset.make_one_shot_iterator()
2208      next_element = iterator.get_next()
2209      _ = array_ops.identity(next_element, name="output")
2210
2211      # Generate three MetaGraphDef protos using different code paths.
2212      meta_graph_def_simple = saver_module.export_meta_graph()
2213      meta_graph_def_devices_cleared = saver_module.export_meta_graph(
2214          clear_devices=True)
2215      meta_graph_def_from_graph_def = saver_module.export_meta_graph(
2216          clear_devices=True, graph_def=g.as_graph_def())
2217
2218    for meta_graph_def in [meta_graph_def_simple,
2219                           meta_graph_def_devices_cleared,
2220                           meta_graph_def_from_graph_def]:
2221      with session.Session(graph=ops_lib.Graph()) as sess:
2222        saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
2223        sess.run(variables.global_variables_initializer())
2224        for i in range(10):
2225          self.assertEqual(i * i, sess.run("new_model/output:0"))
2226        with self.assertRaises(errors.OutOfRangeError):
2227          sess.run("new_model/output:0")
2228
2229
2230@test_util.with_c_api
2231class CheckpointReaderTest(test.TestCase):
2232
2233  _WRITE_VERSION = saver_pb2.SaverDef.V1
2234
2235  def testDebugString(self):
2236    # Builds a graph.
2237    v0 = variables.Variable(
2238        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
2239    v1 = variables.Variable(
2240        [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
2241    init_all_op = variables.global_variables_initializer()
2242    save = saver_module.Saver(
2243        {
2244            "v0": v0,
2245            "v1": v1
2246        }, write_version=self._WRITE_VERSION)
2247    save_path = os.path.join(self.get_temp_dir(),
2248                             "ckpt_for_debug_string" + str(self._WRITE_VERSION))
2249    with self.test_session() as sess:
2250      sess.run(init_all_op)
2251      # Saves a checkpoint.
2252      save.save(sess, save_path)
2253
2254      # Creates a reader.
2255      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
2256      # Verifies that the tensors exist.
2257      self.assertTrue(reader.has_tensor("v0"))
2258      self.assertTrue(reader.has_tensor("v1"))
2259      debug_string = reader.debug_string()
2260      # Verifies that debug string contains the right strings.
2261      self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string)
2262      self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string)
2263      # Verifies get_variable_to_shape_map() returns the correct information.
2264      var_map = reader.get_variable_to_shape_map()
2265      self.assertEqual([2, 3], var_map["v0"])
2266      self.assertEqual([3, 2, 1], var_map["v1"])
2267      # Verifies get_tensor() returns the tensor value.
2268      v0_tensor = reader.get_tensor("v0")
2269      v1_tensor = reader.get_tensor("v1")
2270      self.assertAllEqual(v0.eval(), v0_tensor)
2271      self.assertAllEqual(v1.eval(), v1_tensor)
2272      # Verifies get_tensor() fails for non-existent tensors.
2273      with self.assertRaisesRegexp(errors.NotFoundError,
2274                                   "v3 not found in checkpoint"):
2275        reader.get_tensor("v3")
2276
2277  def testNonexistentPath(self):
2278    with self.assertRaisesRegexp(errors.NotFoundError,
2279                                 "Unsuccessful TensorSliceReader"):
2280      pywrap_tensorflow.NewCheckpointReader("non-existent")
2281
2282
2283@test_util.with_c_api
2284class CheckpointReaderForV2Test(CheckpointReaderTest):
2285  _WRITE_VERSION = saver_pb2.SaverDef.V2
2286
2287
2288@test_util.with_c_api
2289class WriteGraphTest(test.TestCase):
2290
2291  def _get_test_dir(self, dirname):
2292    test_dir = os.path.join(self.get_temp_dir(), dirname)
2293    gfile.MakeDirs(test_dir)
2294    return test_dir
2295
2296  def testWriteGraph(self):
2297    test_dir = self._get_test_dir("write_graph_dir")
2298    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
2299    path = graph_io.write_graph(ops_lib.get_default_graph(),
2300                                os.path.join(test_dir, "l1"), "graph.pbtxt")
2301    truth = os.path.join(test_dir, "l1", "graph.pbtxt")
2302    self.assertEqual(path, truth)
2303    self.assertTrue(os.path.exists(path))
2304
2305  def testRecursiveCreate(self):
2306    test_dir = self._get_test_dir("deep_dir")
2307    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
2308    path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
2309                                os.path.join(test_dir, "l1", "l2", "l3"),
2310                                "graph.pbtxt")
2311    truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt")
2312    self.assertEqual(path, truth)
2313    self.assertTrue(os.path.exists(path))
2314
2315
2316@test_util.with_c_api
2317class SaverUtilsTest(test.TestCase):
2318
2319  def setUp(self):
2320    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
2321    gfile.MakeDirs(self._base_dir)
2322
2323  def tearDown(self):
2324    gfile.DeleteRecursively(self._base_dir)
2325
2326  def testCheckpointExists(self):
2327    for sharded in (False, True):
2328      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
2329        with self.test_session(graph=ops_lib.Graph()) as sess:
2330          unused_v = variables.Variable(1.0, name="v")
2331          variables.global_variables_initializer().run()
2332          saver = saver_module.Saver(sharded=sharded, write_version=version)
2333
2334          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
2335          self.assertFalse(
2336              saver_module.checkpoint_exists(path))  # Not saved yet.
2337
2338          ckpt_prefix = saver.save(sess, path)
2339          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
2340
2341          ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)
2342          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
2343
2344  def testGetCheckpointMtimes(self):
2345    prefixes = []
2346    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
2347      with self.test_session(graph=ops_lib.Graph()) as sess:
2348        unused_v = variables.Variable(1.0, name="v")
2349        variables.global_variables_initializer().run()
2350        saver = saver_module.Saver(write_version=version)
2351        prefixes.append(
2352            saver.save(sess, os.path.join(self._base_dir, str(version))))
2353
2354    mtimes = saver_module.get_checkpoint_mtimes(prefixes)
2355    self.assertEqual(2, len(mtimes))
2356    self.assertTrue(mtimes[1] >= mtimes[0])
2357
2358
2359@test_util.with_c_api
2360class ScopedGraphTest(test.TestCase):
2361
2362  def _get_test_dir(self, dirname):
2363    test_dir = os.path.join(self.get_temp_dir(), dirname)
2364    gfile.MakeDirs(test_dir)
2365    return test_dir
2366
2367  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
2368    graph = ops_lib.Graph()
2369    with graph.as_default():
2370      # Creates an inference graph.
2371      # Hidden 1
2372      images = constant_op.constant(
2373          1.2, dtypes.float32, shape=[100, 28], name="images")
2374      with ops_lib.name_scope("hidden1"):
2375        weights1 = variables.Variable(
2376            random_ops.truncated_normal(
2377                [28, 128], stddev=1.0 / math.sqrt(float(28))),
2378            name="weights")
2379        # The use of control_flow_ops.cond here is purely for adding test
2380        # coverage the save and restore of control flow context (which doesn't
2381        # make any sense here from a machine learning perspective).  The typical
2382        # biases is a simple Variable without the conditions.
2383        biases1 = variables.Variable(
2384            control_flow_ops.cond(
2385                math_ops.less(random.random(), 0.5),
2386                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
2387            name="biases")
2388        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)
2389
2390      # Hidden 2
2391      with ops_lib.name_scope("hidden2"):
2392        weights2 = variables.Variable(
2393            random_ops.truncated_normal(
2394                [128, 32], stddev=1.0 / math.sqrt(float(128))),
2395            name="weights")
2396
2397        # The use of control_flow_ops.while_loop here is purely for adding test
2398        # coverage the save and restore of control flow context (which doesn't
2399        # make any sense here from a machine learning perspective).  The typical
2400        # biases is a simple Variable without the conditions.
2401        def loop_cond(it, _):
2402          return it < 2
2403
2404        def loop_body(it, biases2):
2405          biases2 += constant_op.constant(0.1, shape=[32])
2406          return it + 1, biases2
2407
2408        _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
2409            constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
2410        ])
2411        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
2412      # Linear
2413      with ops_lib.name_scope("softmax_linear"):
2414        weights3 = variables.Variable(
2415            random_ops.truncated_normal(
2416                [32, 10], stddev=1.0 / math.sqrt(float(32))),
2417            name="weights")
2418        biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
2419        logits = math_ops.matmul(hidden2, weights3) + biases3
2420        ops_lib.add_to_collection("logits", logits)
2421
2422        # Adds user_defined proto in three formats: string, bytes and Any.
2423        # Any proto should just pass through.
2424        queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
2425        ops_lib.add_to_collection("user_defined_string_collection",
2426                                  str(queue_runner))
2427        ops_lib.add_to_collection("user_defined_bytes_collection",
2428                                  queue_runner.SerializeToString())
2429        any_buf = Any()
2430        any_buf.Pack(queue_runner)
2431        ops_lib.add_to_collection("user_defined_any_collection", any_buf)
2432
2433      _, var_list = meta_graph.export_scoped_meta_graph(
2434          filename=os.path.join(test_dir, exported_filename),
2435          graph=ops_lib.get_default_graph(),
2436          export_scope="hidden1")
2437      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
2438
2439    with self.test_session(graph=graph) as sess:
2440      sess.run(variables.global_variables_initializer())
2441      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
2442      saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)
2443
2444  def _testScopedRestore(self, test_dir, exported_filename,
2445                         new_exported_filename, ckpt_filename):
2446    graph = ops_lib.Graph()
2447    # Create all the missing inputs.
2448    with graph.as_default():
2449      new_image = constant_op.constant(
2450          1.2, dtypes.float32, shape=[100, 28], name="images")
2451    var_list = meta_graph.import_scoped_meta_graph(
2452        os.path.join(test_dir, exported_filename),
2453        graph=graph,
2454        input_map={"$unbound_inputs_images": new_image},
2455        import_scope="new_hidden1")
2456    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
2457    hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
2458    weights1 = graph.as_graph_element("new_hidden1/weights:0")
2459    biases1 = graph.as_graph_element("new_hidden1/biases:0")
2460
2461    with graph.as_default():
2462      # Hidden 2
2463      with ops_lib.name_scope("hidden2"):
2464        weights = variables.Variable(
2465            random_ops.truncated_normal(
2466                [128, 32], stddev=1.0 / math.sqrt(float(128))),
2467            name="weights")
2468
2469        # The use of control_flow_ops.while_loop here is purely for adding test
2470        # coverage the save and restore of control flow context (which doesn't
2471        # make any sense here from a machine learning perspective).  The typical
2472        # biases is a simple Variable without the conditions.
2473        def loop_cond(it, _):
2474          return it < 2
2475
2476        def loop_body(it, biases):
2477          biases += constant_op.constant(0.1, shape=[32])
2478          return it + 1, biases
2479
2480        _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
2481            constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
2482        ])
2483        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
2484      # Linear
2485      with ops_lib.name_scope("softmax_linear"):
2486        weights = variables.Variable(
2487            random_ops.truncated_normal(
2488                [32, 10], stddev=1.0 / math.sqrt(float(32))),
2489            name="weights")
2490        biases = variables.Variable(array_ops.zeros([10]), name="biases")
2491        logits = math_ops.matmul(hidden2, weights) + biases
2492        ops_lib.add_to_collection("logits", logits)
2493
2494      # The rest of the variables.
2495      rest_variables = list(
2496          set(variables.global_variables()) - set(var_list.keys()))
2497      init_rest_op = variables.initialize_variables(rest_variables)
2498
2499    with self.test_session(graph=graph) as sess:
2500      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
2501      saver.restore(sess, os.path.join(test_dir, ckpt_filename))
2502      # Verify that we have restored weights1 and biases1.
2503      sess.run([weights1, biases1])
2504      # Initialize the rest of the variables and run logits.
2505      sess.run(init_rest_op)
2506      sess.run(logits)
2507
2508  # Verifies that we can save the subgraph under "hidden1" and restore it
2509  # into "new_hidden1" in the new graph.
2510  def testScopedSaveAndRestore(self):
2511    test_dir = self._get_test_dir("scoped_export_import")
2512    ckpt_filename = "ckpt"
2513    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
2514    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
2515                            "exported_new_hidden1.pbtxt", ckpt_filename)
2516
2517  # Verifies that we can copy the subgraph under "hidden1" and copy it
2518  # to different name scope in the same graph or different graph.
2519  def testCopyScopedGraph(self):
2520    test_dir = self._get_test_dir("scoped_copy")
2521    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
2522    graph1 = ops_lib.Graph()
2523    with graph1.as_default():
2524      with ops_lib.name_scope("hidden1"):
2525        images = constant_op.constant(
2526            1.0, dtypes.float32, shape=[3, 2], name="images")
2527        weights1 = variables.Variable(
2528            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
2529        biases1 = variables.Variable([0.1] * 3, name="biases")
2530        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
2531
2532    # Run the graph and save scoped checkpoint.
2533    with self.test_session(graph=graph1) as sess:
2534      sess.run(variables.global_variables_initializer())
2535      _, var_list_1 = meta_graph.export_scoped_meta_graph(
2536          export_scope="hidden1")
2537      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
2538      saver.save(sess, saver0_ckpt, write_state=False)
2539
2540    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))
2541
2542    # Verifies copy to the same graph with the same name fails.
2543    with graph1.as_default():
2544      with self.assertRaisesWithPredicateMatch(
2545          ValueError, lambda e: "need to be different" in str(e)):
2546        meta_graph.copy_scoped_meta_graph(
2547            from_scope="hidden1", to_scope="hidden1")
2548
2549    # Verifies copy to the same graph.
2550    with graph1.as_default():
2551      var_list_2 = meta_graph.copy_scoped_meta_graph(
2552          from_scope="hidden1", to_scope="hidden2")
2553
2554    with self.test_session(graph=graph1) as sess:
2555      saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
2556      saver1.restore(sess, saver0_ckpt)
2557      saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
2558      saver2.restore(sess, saver0_ckpt)
2559      self.assertAllClose(expected, sess.run("hidden1/relu:0"))
2560      self.assertAllClose(expected, sess.run("hidden2/relu:0"))
2561
2562    # Verifies copy to differen graph.
2563    graph2 = ops_lib.Graph()
2564    new_var_list_1 = meta_graph.copy_scoped_meta_graph(
2565        from_scope="hidden1",
2566        to_scope="new_hidden1",
2567        from_graph=graph1,
2568        to_graph=graph2)
2569
2570    with self.test_session(graph=graph2) as sess:
2571      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
2572      saver3.restore(sess, saver0_ckpt)
2573      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
2574
2575  def testExportGraphDefWithScope(self):
2576    test_dir = self._get_test_dir("export_graph_def")
2577    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
2578    graph1 = ops_lib.Graph()
2579    with graph1.as_default():
2580      with ops_lib.name_scope("hidden1"):
2581        images = constant_op.constant(
2582            1.0, dtypes.float32, shape=[3, 2], name="images")
2583        weights1 = variables.Variable(
2584            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
2585        biases1 = variables.Variable([0.1] * 3, name="biases")
2586        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
2587
2588    # Run the graph and save scoped checkpoint.
2589    with self.test_session(graph=graph1) as sess:
2590      sess.run(variables.global_variables_initializer())
2591      _, var_list_1 = meta_graph.export_scoped_meta_graph(
2592          graph_def=graph1.as_graph_def(), export_scope="hidden1")
2593      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
2594      saver.save(sess, saver0_ckpt, write_state=False)
2595
2596    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))
2597
2598    # Verifies that we can run successfully after restoring.
2599    graph2 = ops_lib.Graph()
2600    new_var_list_1 = meta_graph.copy_scoped_meta_graph(
2601        from_scope="hidden1",
2602        to_scope="new_hidden1",
2603        from_graph=graph1,
2604        to_graph=graph2)
2605
2606    with self.test_session(graph=graph2) as sess:
2607      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
2608      saver3.restore(sess, saver0_ckpt)
2609      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
2610
2611  def testSerializeSaverWithScope(self):
2612    test_dir = self._get_test_dir("export_graph_def")
2613    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
2614    saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")
2615    graph = ops_lib.Graph()
2616    with graph.as_default():
2617      with ops_lib.name_scope("hidden1"):
2618        variable1 = variables.Variable([1.0], name="variable1")
2619        saver1 = saver_module.Saver(var_list=[variable1])
2620        graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)
2621
2622      with ops_lib.name_scope("hidden2"):
2623        variable2 = variables.Variable([2.0], name="variable2")
2624      saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
2625      graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
2626
2627    with self.test_session(graph=graph) as sess:
2628      variables.global_variables_initializer().run()
2629      saver1.save(sess, saver1_ckpt, write_state=False)
2630      saver2.save(sess, saver2_ckpt, write_state=False)
2631
2632    graph1 = ops_lib.Graph()
2633    var_dict1 = meta_graph.copy_scoped_meta_graph(
2634        from_scope="hidden1",
2635        to_scope="new_hidden1",
2636        from_graph=graph,
2637        to_graph=graph1)
2638    self.assertEqual(1, len(var_dict1))
2639
2640    saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
2641    self.assertEqual(1, len(saver_list1))
2642
2643    with self.test_session(graph=graph1) as sess:
2644      saver_list1[0].restore(sess, saver1_ckpt)
2645      self.assertEqual(1.0, var_dict1["variable1:0"].eval())
2646
2647    graph2 = ops_lib.Graph()
2648    var_dict2 = meta_graph.copy_scoped_meta_graph(
2649        from_scope="hidden2",
2650        to_scope="new_hidden2",
2651        from_graph=graph,
2652        to_graph=graph2)
2653    self.assertEqual(1, len(var_dict2))
2654
2655    saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
2656    self.assertEqual(1, len(saver_list2))
2657
2658    with self.test_session(graph=graph2) as sess:
2659      saver_list2[0].restore(sess, saver2_ckpt)
2660      self.assertEqual(2.0, var_dict2["variable2:0"].eval())
2661
2662
2663if __name__ == "__main__":
2664  test.main()
2665