1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for tensorflow.python.training.saver.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import math 23import os 24import random 25import shutil 26import tempfile 27import time 28 29import numpy as np 30import six 31 32from google.protobuf.any_pb2 import Any 33from google.protobuf import text_format 34 35from tensorflow.core.protobuf import config_pb2 36from tensorflow.core.protobuf import meta_graph_pb2 37from tensorflow.core.protobuf import queue_runner_pb2 38from tensorflow.core.protobuf import saver_pb2 39from tensorflow.python import pywrap_tensorflow 40from tensorflow.python.client import session 41from tensorflow.python.data.ops import dataset_ops 42from tensorflow.python.eager import context 43from tensorflow.python.framework import constant_op 44from tensorflow.python.framework import dtypes 45from tensorflow.python.framework import errors 46from tensorflow.python.framework import errors_impl 47from tensorflow.python.framework import function 48from tensorflow.python.framework import graph_io 49from tensorflow.python.framework import meta_graph 50from tensorflow.python.framework import ops as ops_lib 51from tensorflow.python.framework import test_util 52from tensorflow.python.lib.io import file_io 53from tensorflow.python.ops import array_ops 54from tensorflow.python.ops import control_flow_ops 55from tensorflow.python.ops import data_flow_ops 56from tensorflow.python.ops import math_ops 57from tensorflow.python.ops import nn_ops 58from tensorflow.python.ops import partitioned_variables 59from tensorflow.python.ops import random_ops 60from tensorflow.python.ops import resource_variable_ops 61from tensorflow.python.ops import sparse_ops 62from tensorflow.python.ops import variable_scope 63from tensorflow.python.ops import variables 64import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 65from tensorflow.python.platform import gfile 66from tensorflow.python.platform import test 67from tensorflow.python.summary import summary 68from tensorflow.python.training import adam 69from tensorflow.python.training import gradient_descent 70from tensorflow.python.training import queue_runner_impl 71from tensorflow.python.training import saver as saver_module 72from tensorflow.python.training import saver_test_utils 73from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 74from tensorflow.python.util import compat 75 76 77@test_util.with_c_api 78class SaverTest(test.TestCase): 79 80 def basicSaveRestore(self, variable_op): 81 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") 82 83 with self.test_session(graph=ops_lib.Graph()) as sess: 84 # Build a graph with 2 parameter nodes, and Save and 85 # Restore nodes for them. 86 v0 = variable_op(10.0, name="v0") 87 v1 = variable_op(20.0, name="v1") 88 v2 = saver_test_utils.CheckpointedOp(name="v2") 89 v2_init = v2.insert("k1", 30.0) 90 91 # Initialize all variables 92 if context.in_graph_mode(): 93 self.evaluate([variables.global_variables_initializer(), v2_init]) 94 95 # Check that the parameter nodes have been initialized. 96 self.assertEqual(10.0, self.evaluate(v0)) 97 self.assertEqual(20.0, self.evaluate(v1)) 98 self.assertEqual(b"k1", self.evaluate(v2.keys())) 99 self.assertEqual(30.0, self.evaluate(v2.values())) 100 101 # Save the initialized values in the file at "save_path" 102 save = saver_module.Saver( 103 { 104 "v0": v0, 105 "v1": v1, 106 "v2": v2.saveable 107 }, restore_sequentially=True) 108 val = save.save(sess, save_path) 109 self.assertTrue(isinstance(val, six.string_types)) 110 self.assertEqual(save_path, val) 111 112 # Start a second session. In that session the parameter nodes 113 # have not been initialized either. 114 with self.test_session(graph=ops_lib.Graph()) as sess: 115 v0 = variable_op(-1.0, name="v0") 116 v1 = variable_op(-1.0, name="v1") 117 v2 = saver_test_utils.CheckpointedOp(name="v2") 118 119 # Assert that the variables are not initialized. 120 if context.in_graph_mode(): 121 self.assertEqual( 122 len(variables.report_uninitialized_variables().eval()), 2) 123 self.assertEqual(0, len(v2.keys().eval())) 124 self.assertEqual(0, len(v2.values().eval())) 125 # Restore the saved values in the parameter nodes. 126 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 127 save.restore(sess, save_path) 128 # Check that the parameter nodes have been restored. 129 self.assertEqual(10.0, self.evaluate(v0)) 130 self.assertEqual(20.0, self.evaluate(v1)) 131 self.assertEqual(b"k1", self.evaluate(v2.keys())) 132 self.assertEqual(30.0, self.evaluate(v2.values())) 133 134 # Build another graph with 2 nodes, initialized 135 # differently, and a Restore node for them. 136 with self.test_session(graph=ops_lib.Graph()) as sess: 137 v0_2 = variable_op(1000.0, name="v0") 138 v1_2 = variable_op(2000.0, name="v1") 139 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 140 v2_init = v2_2.insert("k1000", 3000.0) 141 142 # Check that the parameter nodes have been initialized. 143 if context.in_graph_mode(): 144 init_all_op = [variables.global_variables_initializer(), v2_init] 145 self.evaluate(init_all_op) 146 # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty 147 # table as it claims in eager mode? 148 self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) 149 self.assertEqual(3000.0, self.evaluate(v2_2.values())) 150 self.assertEqual(1000.0, self.evaluate(v0_2)) 151 self.assertEqual(2000.0, self.evaluate(v1_2)) 152 153 # Restore the values saved earlier in the parameter nodes. 154 save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable}) 155 save2.restore(sess, save_path) 156 # Check that the parameter nodes have been restored. 157 self.assertEqual(10.0, self.evaluate(v0_2)) 158 self.assertEqual(20.0, self.evaluate(v1_2)) 159 self.assertEqual(b"k1", self.evaluate(v2_2.keys())) 160 self.assertEqual(30.0, self.evaluate(v2_2.values())) 161 162 def testBasic(self): 163 self.basicSaveRestore(variables.Variable) 164 165 @test_util.run_in_graph_and_eager_modes() 166 def testResourceBasic(self): 167 self.basicSaveRestore(resource_variable_ops.ResourceVariable) 168 169 def testResourceVariableReadOpsAddedDeterministically(self): 170 graph_defs = [] 171 num_graphs = 10 172 for _ in range(num_graphs): 173 with ops_lib.Graph().as_default() as g: 174 for i in range(20): 175 resource_variable_ops.ResourceVariable(i, name="var%s" % i) 176 saver_module.Saver() 177 graph_defs.append(g.as_graph_def()) 178 for i in range(num_graphs - 1): 179 self.assertEqual(graph_defs[i], graph_defs[i + 1]) 180 181 def testEagerBasic(self): 182 with context.eager_mode(): 183 ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") 184 185 v1 = resource_variable_ops.ResourceVariable(3.14, name="v1") 186 v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2") 187 save = saver_module.Saver([v1, v2]) 188 save.save(None, ckpt_prefix) 189 190 v1.assign(0.0) 191 v2.assign([0, 0]) 192 self.assertNear(0.0, self.evaluate(v1), 1e-5) 193 self.assertAllEqual([0, 0], self.evaluate(v2)) 194 195 save.restore(None, ckpt_prefix) 196 self.assertNear(3.14, self.evaluate(v1), 1e-5) 197 self.assertAllEqual([1, 2], self.evaluate(v2)) 198 199 def testEagerGraphCompatibility(self): 200 # Save from graph mode and restore from eager mode. 201 graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt") 202 with context.graph_mode(): 203 with self.test_session(graph=ops_lib.Graph()) as sess: 204 # Create a graph model and save the checkpoint. 205 w1 = resource_variable_ops.ResourceVariable(1.0, name="w1") 206 w2 = resource_variable_ops.ResourceVariable(2.0, name="w2") 207 graph_saver = saver_module.Saver([w1, w2]) 208 sess.run(variables.global_variables_initializer()) 209 graph_saver.save(sess, graph_ckpt_prefix) 210 211 with context.eager_mode(): 212 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 213 ops_lib.reset_default_graph() 214 215 w1 = resource_variable_ops.ResourceVariable(0.0, name="w1") 216 w2 = resource_variable_ops.ResourceVariable(0.0, name="w2") 217 218 graph_saver = saver_module.Saver([w1, w2]) 219 graph_saver.restore(None, graph_ckpt_prefix) 220 221 self.assertAllEqual(self.evaluate(w1), 1.0) 222 self.assertAllEqual(self.evaluate(w2), 2.0) 223 224 # Save from eager mode and restore from graph mode. 225 eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt") 226 with context.eager_mode(): 227 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 228 ops_lib.reset_default_graph() 229 230 w3 = resource_variable_ops.ResourceVariable(3.0, name="w3") 231 w4 = resource_variable_ops.ResourceVariable(4.0, name="w4") 232 233 graph_saver = saver_module.Saver([w3, w4]) 234 graph_saver.save(None, eager_ckpt_prefix) 235 236 with context.graph_mode(): 237 with self.test_session(graph=ops_lib.Graph()) as sess: 238 w3 = resource_variable_ops.ResourceVariable(0.0, name="w3") 239 w4 = resource_variable_ops.ResourceVariable(0.0, name="w4") 240 graph_saver = saver_module.Saver([w3, w4]) 241 sess.run(variables.global_variables_initializer()) 242 graph_saver.restore(sess, eager_ckpt_prefix) 243 self.assertAllEqual(w3.eval(), 3.0) 244 self.assertAllEqual(w4.eval(), 4.0) 245 246 @test_util.run_in_graph_and_eager_modes() 247 def testResourceSaveRestoreCachingDevice(self): 248 save_path = os.path.join(self.get_temp_dir(), "resource_cache") 249 with self.test_session(graph=ops_lib.Graph()) as sess: 250 v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", 251 name="v") 252 if context.in_graph_mode(): 253 self.evaluate(variables.global_variables_initializer()) 254 else: 255 sess = None 256 save = saver_module.Saver([v]) 257 save.save(sess, save_path) 258 259 save2 = saver_module.Saver([v]) 260 save2.restore(sess, save_path) 261 self.assertEquals(self.evaluate(v), [1]) 262 263 def testSaveCopyRestoreWithSaveRelativePaths(self): 264 """Save, copy checkpoint dir and restore from copied dir. 265 266 This only works for save_relative_paths=True. 267 """ 268 save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1") 269 os.mkdir(save_dir1) 270 save_path1 = os.path.join(save_dir1, "save_copy_restore") 271 272 # Build a graph with 2 parameter nodes, and Save and 273 # Restore nodes for them. 274 v0 = variables.Variable(10.0, name="v0") 275 v1 = variables.Variable(20.0, name="v1") 276 v2 = saver_test_utils.CheckpointedOp(name="v2") 277 v2_init = v2.insert("k1", 30.0) 278 save = saver_module.Saver( 279 var_list={ 280 "v0": v0, 281 "v1": v1, 282 "v2": v2.saveable}, 283 restore_sequentially=True, 284 save_relative_paths=True) 285 init_all_op = [variables.global_variables_initializer(), v2_init] 286 287 with self.test_session() as sess: 288 # Initialize all variables 289 sess.run(init_all_op) 290 291 # Check that the parameter nodes have been initialized. 292 self.assertEqual(10.0, v0.eval()) 293 self.assertEqual(20.0, v1.eval()) 294 self.assertEqual(b"k1", v2.keys().eval()) 295 self.assertEqual(30.0, v2.values().eval()) 296 297 # Save the initialized values in the file at "save_path" 298 val = save.save(sess, save_path1) 299 self.assertTrue(isinstance(val, six.string_types)) 300 self.assertEqual(save_path1, val) 301 302 self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1) 303 save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") 304 os.renames(save_dir1, save_dir2) 305 save_path2 = os.path.join(save_dir2, "save_copy_restore") 306 self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2) 307 308 # Start a second session. In that session the parameter nodes 309 # have not been initialized either. 310 with self.test_session() as sess: 311 v0 = variables.Variable(-1.0, name="v0") 312 v1 = variables.Variable(-1.0, name="v1") 313 v2 = saver_test_utils.CheckpointedOp(name="v2") 314 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 315 316 # Assert that the variables are not initialized. 317 self.assertEqual( 318 len(variables.report_uninitialized_variables().eval()), 2) 319 self.assertEqual(0, len(v2.keys().eval())) 320 self.assertEqual(0, len(v2.values().eval())) 321 322 # Restore the saved values in the parameter nodes. 323 save.restore(sess, save_path2) 324 # Check that the parameter nodes have been restored. 325 self.assertEqual(10.0, v0.eval()) 326 self.assertEqual(20.0, v1.eval()) 327 self.assertEqual(b"k1", v2.keys().eval()) 328 self.assertEqual(30.0, v2.values().eval()) 329 330 def testFilenameTensor(self): 331 v0 = variables.Variable(0, name="v0") 332 filename = b"somerandomfilename" 333 save = saver_module.Saver({"v0": v0}, filename=filename) 334 with self.test_session() as sess: 335 tensor = sess.graph.get_tensor_by_name( 336 save.saver_def.filename_tensor_name) 337 self.assertEqual(sess.run(tensor), filename) 338 339 def testInvalidPath(self): 340 v0 = variables.Variable(0, name="v0") 341 for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): 342 with self.test_session() as sess: 343 save = saver_module.Saver({"v0": v0}, write_version=ver) 344 with self.assertRaisesRegexp(errors.NotFoundError, 345 "Failed to find any matching files for"): 346 save.restore(sess, "invalid path") 347 348 def testInt64(self): 349 save_path = os.path.join(self.get_temp_dir(), "int64") 350 351 with self.test_session() as sess: 352 # Build a graph with 1 node, and save and restore for them. 353 v = variables.Variable(np.int64(15), name="v") 354 save = saver_module.Saver({"v": v}, restore_sequentially=True) 355 variables.global_variables_initializer().run() 356 357 # Save the initialized values in the file at "save_path" 358 val = save.save(sess, save_path) 359 self.assertTrue(isinstance(val, six.string_types)) 360 self.assertEqual(save_path, val) 361 362 with self.test_session() as sess: 363 v = variables.Variable(np.int64(-1), name="v") 364 save = saver_module.Saver({"v": v}) 365 366 with self.assertRaisesWithPredicateMatch( 367 errors_impl.OpError, lambda e: "uninitialized value v" in e.message): 368 sess.run(v) 369 370 # Restore the saved values in the parameter nodes. 371 save.restore(sess, save_path) 372 # Check that the parameter nodes have been restored. 373 self.assertEqual(np.int64(15), v.eval()) 374 375 def testSomeErrors(self): 376 with ops_lib.Graph().as_default(): 377 v0 = variables.Variable([10.0], name="v0") 378 v1 = variables.Variable([20.0], name="v1") 379 v2 = variables.Variable([20.0], name="v2") 380 v2._set_save_slice_info( 381 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 382 383 # By default the name used for "v2" will be "v1" and raise an error. 384 with self.assertRaisesRegexp(ValueError, "same name: v1"): 385 saver_module.Saver([v0, v1, v2]) 386 387 # The names are different and will work. 388 saver_module.Saver({"vee1": v1, "other": [v2]}) 389 390 # Partitioned variables also cause name conflicts. 391 p_v1 = variable_scope.get_variable( 392 "p_v1", 393 shape=[4, 5], 394 partitioner=partitioned_variables.fixed_size_partitioner( 395 num_shards=2)) 396 p_v2 = variable_scope.get_variable( 397 "p_v2", 398 shape=[4, 5], 399 partitioner=partitioned_variables.fixed_size_partitioner( 400 num_shards=2)) 401 p_v2._name = "p_v1" 402 with self.assertRaisesRegexp(ValueError, "same name: p_v1"): 403 saver_module.Saver([p_v1, p_v2]) 404 405 def testSameName(self): 406 with ops_lib.Graph().as_default(): 407 v0 = variables.Variable([10.0], name="v0") 408 v2 = saver_test_utils.CheckpointedOp(name="v2") 409 410 # Saving one variable under two names raises an error. 411 with self.assertRaisesRegexp( 412 ValueError, "The same saveable will be restored with two names: v0"): 413 saver_module.Saver({"v0": v0, "v0too": v0}) 414 415 # Ditto for custom saveables. 416 with self.assertRaisesRegexp( 417 ValueError, "The same saveable will be restored with two names: v2"): 418 saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable}) 419 420 # Verify non-duplicate names work. 421 saver_module.Saver({"v0": v0, "v2": v2.saveable}) 422 423 def testBasicsWithListOfVariables(self): 424 save_path = os.path.join(self.get_temp_dir(), "basics_with_list") 425 426 with self.test_session(graph=ops_lib.Graph()) as sess: 427 # Build a graph with 2 parameter nodes, and Save and 428 # Restore nodes for them. 429 v0 = variables.Variable(10.0, name="v0") 430 v1 = variables.Variable(20.0, name="v1") 431 v2 = saver_test_utils.CheckpointedOp(name="v2") 432 v2_init = v2.insert("k1", 30.0) 433 save = saver_module.Saver([v0, v1, v2.saveable]) 434 variables.global_variables_initializer().run() 435 v2_init.run() 436 437 # Check that the parameter nodes have been initialized. 438 self.assertEqual(10.0, v0.eval()) 439 self.assertEqual(20.0, v1.eval()) 440 self.assertEqual(b"k1", v2.keys().eval()) 441 self.assertEqual(30.0, v2.values().eval()) 442 443 # Save the initialized values in the file at "save_path" 444 val = save.save(sess, save_path) 445 self.assertTrue(isinstance(val, six.string_types)) 446 self.assertEqual(save_path, val) 447 448 # Start a second session. In that session the variables 449 # have not been initialized either. 450 with self.test_session(graph=ops_lib.Graph()) as sess: 451 v0 = variables.Variable(-1.0, name="v0") 452 v1 = variables.Variable(-1.0, name="v1") 453 v2 = saver_test_utils.CheckpointedOp(name="v2") 454 save = saver_module.Saver([v0, v1, v2.saveable]) 455 456 with self.assertRaisesWithPredicateMatch( 457 errors_impl.OpError, lambda e: "uninitialized value v0" in e.message): 458 sess.run(v0) 459 with self.assertRaisesWithPredicateMatch( 460 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 461 sess.run(v1) 462 self.assertEqual(0, len(v2.keys().eval())) 463 self.assertEqual(0, len(v2.values().eval())) 464 465 # Restore the saved values in the parameter nodes. 466 save.restore(sess, save_path) 467 # Check that the parameter nodes have been restored. 468 self.assertEqual(10.0, v0.eval()) 469 self.assertEqual(20.0, v1.eval()) 470 self.assertEqual(b"k1", v2.keys().eval()) 471 self.assertEqual(30.0, v2.values().eval()) 472 473 # Build another graph with 2 nodes, initialized 474 # differently, and a Restore node for them. 475 with self.test_session(graph=ops_lib.Graph()) as sess: 476 v0_2 = variables.Variable(1000.0, name="v0") 477 v1_2 = variables.Variable(2000.0, name="v1") 478 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 479 save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable]) 480 v2_2.insert("k1000", 3000.0).run() 481 variables.global_variables_initializer().run() 482 483 # Check that the parameter nodes have been initialized. 484 self.assertEqual(1000.0, v0_2.eval()) 485 self.assertEqual(2000.0, v1_2.eval()) 486 self.assertEqual(b"k1000", v2_2.keys().eval()) 487 self.assertEqual(3000.0, v2_2.values().eval()) 488 # Restore the values saved earlier in the parameter nodes. 489 save2.restore(sess, save_path) 490 # Check that the parameter nodes have been restored. 491 self.assertEqual(10.0, v0_2.eval()) 492 self.assertEqual(20.0, v1_2.eval()) 493 self.assertEqual(b"k1", v2_2.keys().eval()) 494 self.assertEqual(30.0, v2_2.values().eval()) 495 496 def _SaveAndLoad(self, var_name, var_value, other_value, save_path): 497 with self.test_session(graph=ops_lib.Graph()) as sess: 498 var = resource_variable_ops.ResourceVariable(var_value, name=var_name) 499 save = saver_module.Saver({var_name: var}) 500 if context.in_graph_mode(): 501 self.evaluate(var.initializer) 502 val = save.save(sess, save_path) 503 self.assertEqual(save_path, val) 504 with self.test_session(graph=ops_lib.Graph()) as sess: 505 var = resource_variable_ops.ResourceVariable(other_value, name=var_name) 506 save = saver_module.Saver({var_name: var}) 507 save.restore(sess, save_path) 508 self.assertAllClose(var_value, self.evaluate(var)) 509 510 def testCacheRereadsFile(self): 511 save_path = os.path.join(self.get_temp_dir(), "cache_rereads") 512 # Save and reload one Variable named "var0". 513 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 514 # Save and reload one Variable named "var1" in the same file. 515 # The cached readers should know to re-read the file. 516 self._SaveAndLoad("var1", 1.1, 2.2, save_path) 517 518 def testAllowEmpty(self): 519 save_path = os.path.join(self.get_temp_dir(), "allow_empty") 520 with self.test_session() as sess: 521 _ = constant_op.constant(1) 522 save = saver_module.Saver(allow_empty=True) 523 val = save.save(sess, save_path) 524 self.assertIsNone(val) 525 with self.test_session() as sess: 526 save = saver_module.Saver(allow_empty=True) 527 save.restore(sess, save_path) 528 529 def testGPU(self): 530 if not test.is_gpu_available(): 531 return 532 save_path = os.path.join(self.get_temp_dir(), "gpu") 533 with session.Session("", graph=ops_lib.Graph()) as sess: 534 with sess.graph.device(test.gpu_device_name()): 535 v0_1 = variables.Variable(123.45) 536 save = saver_module.Saver({"v0": v0_1}) 537 variables.global_variables_initializer().run() 538 save.save(sess, save_path) 539 540 with session.Session("", graph=ops_lib.Graph()) as sess: 541 with sess.graph.device(test.gpu_device_name()): 542 v0_2 = variables.Variable(543.21) 543 save = saver_module.Saver({"v0": v0_2}) 544 variables.global_variables_initializer().run() 545 546 def testSharedServerOnGPU(self): 547 if not test.is_gpu_available(): 548 return 549 save_path = os.path.join(self.get_temp_dir(), "gpu") 550 with session.Session("", graph=ops_lib.Graph()) as sess: 551 with sess.graph.device(test.gpu_device_name()): 552 v0_1 = variables.Variable(123.45) 553 save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True) 554 variables.global_variables_initializer().run() 555 save.save(sess, save_path) 556 557 with session.Session("", graph=ops_lib.Graph()) as sess: 558 with sess.graph.device(test.gpu_device_name()): 559 v0_2 = variables.Variable(543.21) 560 save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True) 561 variables.global_variables_initializer().run() 562 563 def testVariables(self): 564 save_path = os.path.join(self.get_temp_dir(), "variables") 565 with session.Session("", graph=ops_lib.Graph()) as sess: 566 one = variables.Variable(1.0) 567 twos = variables.Variable([2.0, 2.0, 2.0]) 568 v2 = saver_test_utils.CheckpointedOp(name="v2") 569 init = variables.global_variables_initializer() 570 save = saver_module.Saver() 571 init.run() 572 v2.insert("k1", 3.0).run() 573 save.save(sess, save_path) 574 575 with session.Session("", graph=ops_lib.Graph()) as sess: 576 one = variables.Variable(0.0) 577 twos = variables.Variable([0.0, 0.0, 0.0]) 578 v2 = saver_test_utils.CheckpointedOp(name="v2") 579 # Saver with no arg, defaults to 'all variables'. 580 save = saver_module.Saver() 581 save.restore(sess, save_path) 582 self.assertAllClose(1.0, one.eval()) 583 self.assertAllClose([2.0, 2.0, 2.0], twos.eval()) 584 self.assertEqual(b"k1", v2.keys().eval()) 585 self.assertEqual(3.0, v2.values().eval()) 586 587 def testVarListShouldBeEmptyInDeferredBuild(self): 588 with ops_lib.Graph().as_default(): 589 v = variables.Variable(1.0) 590 with self.assertRaisesRegexp(ValueError, "defer_build"): 591 saver_module.Saver([v], defer_build=True) 592 593 def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self): 594 save_path = os.path.join(self.get_temp_dir(), "error_deferred_build") 595 with ops_lib.Graph().as_default(), session.Session() as sess: 596 variables.Variable(1.0) 597 saver = saver_module.Saver(defer_build=True) 598 with self.assertRaisesRegexp(RuntimeError, "build"): 599 saver.save(sess, save_path) 600 601 def testDeferredBuild(self): 602 save_path = os.path.join(self.get_temp_dir(), "deferred_build") 603 with session.Session("", graph=ops_lib.Graph()) as sess: 604 one = variables.Variable(1.0) 605 save = saver_module.Saver(defer_build=True) 606 # if build is not deferred, saver cannot save the `twos`. 607 twos = variables.Variable([2.0, 2.0, 2.0]) 608 init = variables.global_variables_initializer() 609 save.build() 610 init.run() 611 save.save(sess, save_path) 612 613 with session.Session("", graph=ops_lib.Graph()) as sess: 614 one = variables.Variable(0.0) 615 twos = variables.Variable([0.0, 0.0, 0.0]) 616 # Saver with no arg, defaults to 'all variables'. 617 save = saver_module.Saver() 618 save.restore(sess, save_path) 619 self.assertAllClose(1.0, one.eval()) 620 self.assertAllClose([2.0, 2.0, 2.0], twos.eval()) 621 622 def testReshape(self): 623 save_path = os.path.join(self.get_temp_dir(), "variables_reshape") 624 with session.Session("", graph=ops_lib.Graph()) as sess: 625 var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 626 init = variables.global_variables_initializer() 627 save = saver_module.Saver() 628 init.run() 629 save.save(sess, save_path) 630 631 # Error when restoring with default reshape=False 632 with session.Session("", graph=ops_lib.Graph()) as sess: 633 var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 634 save = saver_module.Saver() 635 with self.assertRaisesRegexp( 636 errors_impl.InvalidArgumentError, 637 "Assign requires shapes of both tensors to match."): 638 save.restore(sess, save_path) 639 640 # Restored to new shape with reshape=True 641 with session.Session("", graph=ops_lib.Graph()) as sess: 642 var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 643 save = saver_module.Saver(reshape=True) 644 save.restore(sess, save_path) 645 self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval()) 646 647 @test_util.run_in_graph_and_eager_modes() 648 def testSaveWithGlobalStep(self, pad_step_number=False): 649 save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") 650 global_step_int = 5 651 # Save and reload one Variable named "var0". 652 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 653 for use_tensor in [True, False]: 654 with self.test_session(graph=ops_lib.Graph()): 655 var = resource_variable_ops.ResourceVariable(1.0, name="var0") 656 save = saver_module.Saver( 657 { 658 var._shared_name: var 659 }, pad_step_number=pad_step_number) 660 if context.in_graph_mode(): 661 self.evaluate(var.initializer) 662 sess = ops_lib.get_default_session() 663 else: 664 sess = None 665 if use_tensor: 666 global_step = constant_op.constant(global_step_int) 667 val = save.save(sess, save_path, global_step=global_step) 668 else: 669 val = save.save(sess, save_path, global_step=global_step_int) 670 if pad_step_number: 671 expected_save_path = "%s-%s" % (save_path, 672 "{:08d}".format(global_step_int)) 673 else: 674 expected_save_path = "%s-%d" % (save_path, global_step_int) 675 self.assertEqual(expected_save_path, val) 676 677 def testSaveWithGlobalStepWithPadding(self): 678 self.testSaveWithGlobalStep(pad_step_number=True) 679 680 def testSaveToNonexistingPath(self): 681 file_io.write_string_to_file( 682 os.path.join(self.get_temp_dir(), "actually_a_file"), "") 683 paths = [ 684 os.path.join(self.get_temp_dir(), "nonexisting_dir/path"), 685 os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"), 686 os.path.join(self.get_temp_dir(), "actually_a_file/path"), 687 ] 688 689 for save_path in paths: 690 # Build a graph with 2 parameter nodes, and Save and 691 # Restore nodes for them. 692 v0 = variables.Variable(10.0, name="v0") 693 v1 = variables.Variable(20.0, name="v1") 694 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 695 init_all_op = variables.global_variables_initializer() 696 697 # In the case where the parent directory doesn't exist, whether or not the 698 # save succeeds or fails is implementation dependent. Therefore we allow 699 # both cases. 700 try: 701 with self.test_session() as sess: 702 # Initialize all variables 703 sess.run(init_all_op) 704 705 # Check that the parameter nodes have been initialized. 706 self.assertEqual(10.0, v0.eval()) 707 self.assertEqual(20.0, v1.eval()) 708 709 # Save the graph. 710 save.save(sess, save_path) 711 712 with self.test_session() as sess: 713 # Restore the saved values in the parameter nodes. 714 save.restore(sess, save_path) 715 # Check that the parameter nodes have been restored. 716 self.assertEqual(10.0, v0.eval()) 717 self.assertEqual(20.0, v1.eval()) 718 except ValueError as exc: 719 error_msg_template = "Parent directory of {} doesn't exist, can't save." 720 self.assertEqual(error_msg_template.format(save_path), str(exc)) 721 722 def testSaveToURI(self): 723 # ParseURI functions don't work on Windows yet. 724 # TODO(jhseu): Remove this check when it works. 725 if os.name == "nt": 726 self.skipTest("Local URI support doesn't work on Windows") 727 save_path = "file://" + os.path.join(self.get_temp_dir(), "uri") 728 729 # Build a graph with 2 parameter nodes, and Save and 730 # Restore nodes for them. 731 v0 = variables.Variable(10.0, name="v0") 732 v1 = variables.Variable(20.0, name="v1") 733 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 734 init_all_op = variables.global_variables_initializer() 735 736 with self.test_session() as sess: 737 # Initialize all variables 738 sess.run(init_all_op) 739 740 # Check that the parameter nodes have been initialized. 741 self.assertEqual(10.0, v0.eval()) 742 self.assertEqual(20.0, v1.eval()) 743 save.save(sess, save_path) 744 745 746@test_util.with_c_api 747class SaveRestoreShardedTest(test.TestCase): 748 749 _WRITE_VERSION = saver_pb2.SaverDef.V1 750 751 def _get_test_dir(self, dirname): 752 test_dir = os.path.join(self.get_temp_dir(), dirname) 753 gfile.MakeDirs(test_dir) 754 return test_dir 755 756 def testBasics(self): 757 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 758 759 # Build a graph with 2 parameter nodes on different devices. 760 with session.Session( 761 target="", 762 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 763 with sess.graph.device("/cpu:0"): 764 v0 = variables.Variable(10, name="v0") 765 t0 = saver_test_utils.CheckpointedOp(name="t0") 766 with sess.graph.device("/cpu:1"): 767 v1 = variables.Variable(20, name="v1") 768 t1 = saver_test_utils.CheckpointedOp(name="t1") 769 save = saver_module.Saver( 770 { 771 "v0": v0, 772 "v1": v1, 773 "t0": t0.saveable, 774 "t1": t1.saveable 775 }, 776 write_version=self._WRITE_VERSION, 777 sharded=True) 778 variables.global_variables_initializer().run() 779 t0.insert("k1", 30.0).run() 780 t1.insert("k2", 40.0).run() 781 val = save.save(sess, save_path) 782 if save._write_version is saver_pb2.SaverDef.V1: 783 self.assertEqual(save_path + "-?????-of-00002", val) 784 else: 785 self.assertEqual(save_path, val) 786 meta_graph_filename = save._MetaGraphFilename(val) 787 self.assertEqual(save_path + ".meta", meta_graph_filename) 788 789 if save._write_version is saver_pb2.SaverDef.V1: 790 # Restore different ops from shard 0 of the saved files. 791 with session.Session( 792 target="", 793 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 794 with sess.graph.device("/cpu:0"): 795 v0 = variables.Variable(111, name="v0") 796 t0 = saver_test_utils.CheckpointedOp(name="t0") 797 save = saver_module.Saver( 798 { 799 "v0": v0, 800 "t0": t0.saveable 801 }, 802 write_version=self._WRITE_VERSION, 803 sharded=True) 804 variables.global_variables_initializer().run() 805 t0.insert("k11", 33.0).run() 806 self.assertEqual(111, v0.eval()) 807 self.assertEqual(b"k11", t0.keys().eval()) 808 self.assertEqual(33.0, t0.values().eval()) 809 save.restore(sess, save_path + "-00000-of-00002") 810 self.assertEqual(10, v0.eval()) 811 self.assertEqual(b"k1", t0.keys().eval()) 812 self.assertEqual(30.0, t0.values().eval()) 813 814 # Restore different ops from shard 1 of the saved files. 815 with session.Session( 816 target="", 817 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 818 with sess.graph.device("/cpu:0"): 819 v1 = variables.Variable(222) 820 t1 = saver_test_utils.CheckpointedOp(name="t1") 821 save = saver_module.Saver( 822 { 823 "v1": v1, 824 "t1": t1.saveable 825 }, 826 write_version=self._WRITE_VERSION, 827 sharded=True) 828 variables.global_variables_initializer().run() 829 t1.insert("k22", 44.0).run() 830 self.assertEqual(222, v1.eval()) 831 self.assertEqual(b"k22", t1.keys().eval()) 832 self.assertEqual(44.0, t1.values().eval()) 833 save.restore(sess, save_path + "-00001-of-00002") 834 self.assertEqual(20, v1.eval()) 835 self.assertEqual(b"k2", t1.keys().eval()) 836 self.assertEqual(40.0, t1.values().eval()) 837 838 # Now try a restore with the sharded filename. 839 with session.Session( 840 target="", 841 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 842 with sess.graph.device("/cpu:0"): 843 v0 = variables.Variable(111, name="v0") 844 t0 = saver_test_utils.CheckpointedOp(name="t0") 845 with sess.graph.device("/cpu:1"): 846 v1 = variables.Variable(222, name="v1") 847 t1 = saver_test_utils.CheckpointedOp(name="t1") 848 save = saver_module.Saver( 849 { 850 "v0": v0, 851 "v1": v1, 852 "t0": t0.saveable, 853 "t1": t1.saveable 854 }, 855 write_version=self._WRITE_VERSION, 856 sharded=True) 857 variables.global_variables_initializer().run() 858 t0.insert("k11", 33.0).run() 859 t1.insert("k22", 44.0).run() 860 self.assertEqual(111, v0.eval()) 861 self.assertEqual(222, v1.eval()) 862 self.assertEqual(b"k11", t0.keys().eval()) 863 self.assertEqual(33.0, t0.values().eval()) 864 self.assertEqual(b"k22", t1.keys().eval()) 865 self.assertEqual(44.0, t1.values().eval()) 866 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 867 if save._write_version is saver_pb2.SaverDef.V1: 868 save.restore(sess, save_path + "-?????-of-?????") 869 else: 870 save.restore(sess, save_path) 871 self.assertEqual(10, v0.eval()) 872 self.assertEqual(20, v1.eval()) 873 self.assertEqual(b"k1", t0.keys().eval()) 874 self.assertEqual(30.0, t0.values().eval()) 875 self.assertEqual(b"k2", t1.keys().eval()) 876 self.assertEqual(40.0, t1.values().eval()) 877 878 if save._write_version is saver_pb2.SaverDef.V1: 879 self.assertEqual( 880 saver_module.latest_checkpoint(self.get_temp_dir()), 881 os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002")) 882 else: 883 self.assertEqual( 884 saver_module.latest_checkpoint(self.get_temp_dir()), 885 os.path.join(self.get_temp_dir(), "sharded_basics")) 886 887 def testSaverDef(self): 888 with self.test_session(): 889 v0 = variables.Variable(123, name="v0") 890 save = saver_module.Saver({"v0": v0}, sharded=True) 891 sd = save.as_saver_def() 892 self.assertTrue(sd.sharded) 893 894 def _testPartitionedVariables(self, use_resource): 895 var_full_shape = [10, 3] 896 # Allows save/restore mechanism to work w/ different slicings. 897 var_name = "my_var" 898 saved_dir = self._get_test_dir("partitioned_variables") 899 saved_path = os.path.join(saved_dir, "ckpt") 900 901 call_saver_with_dict = False # updated by test loop below 902 903 def _save(slices=None, partitioner=None): 904 with self.test_session(graph=ops_lib.Graph()) as sess: 905 # Calls .eval() to return the ndarray that makes up the full variable. 906 rnd = random_ops.random_uniform(var_full_shape).eval() 907 908 if slices: 909 assert not partitioner 910 # TODO(apassos): make create_partitioned_variables take use_resource 911 # option to make this test passable without creating a named 912 # variable_scope. 913 vs = partitioned_variables.create_partitioned_variables( 914 var_full_shape, slices, rnd, name=var_name) 915 elif partitioner: 916 vs = [ 917 variable_scope.get_variable( 918 var_name, 919 shape=var_full_shape, 920 initializer=rnd, 921 partitioner=partitioner, 922 use_resource=use_resource) 923 ] 924 else: 925 if use_resource: 926 vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)] 927 else: 928 vs = [variables.Variable(rnd, name=var_name)] 929 930 variables.global_variables_initializer().run() 931 if call_saver_with_dict: 932 saver = saver_module.Saver({var_name: (vs if slices else vs[0])}) 933 else: 934 saver = saver_module.Saver(vs) 935 actual_path = saver.save(sess, saved_path) 936 self.assertEqual(saved_path, actual_path) 937 938 return rnd 939 940 def _restore(slices=None, partitioner=None): 941 with self.test_session(graph=ops_lib.Graph()) as sess: 942 if slices: 943 assert not partitioner 944 new_vs = partitioned_variables.create_partitioned_variables( 945 var_full_shape, 946 slices, 947 array_ops.zeros(var_full_shape), # != original contents. 948 name=var_name) 949 elif partitioner: 950 new_vs = [ 951 variable_scope.get_variable( 952 var_name, 953 shape=var_full_shape, 954 initializer=array_ops.zeros(var_full_shape), 955 partitioner=partitioner) 956 ] 957 else: 958 new_vs = [ 959 variables.Variable( 960 array_ops.zeros( 961 shape=var_full_shape), # != original contents. 962 name=var_name) 963 ] 964 965 variables.global_variables_initializer().run() 966 if call_saver_with_dict: 967 saver = saver_module.Saver({ 968 var_name: (new_vs if slices else new_vs[0]) 969 }) 970 else: 971 saver = saver_module.Saver(new_vs) 972 saver.restore(sess, saved_path) 973 974 if partitioner: 975 return new_vs[0].as_tensor().eval() 976 elif slices and slices[0] != 1: 977 return array_ops.concat(new_vs, 0).eval() 978 elif slices and slices[1] != 1: 979 return array_ops.concat(new_vs, 1).eval() 980 else: # Non-sliced. 981 return new_vs[0].eval() 982 983 for call_saver_with_dict in {False, True}: 984 # Save PartitionedVariable and restore into full variable. 985 saved_full = _save( 986 partitioner=partitioned_variables.fixed_size_partitioner( 987 num_shards=2)) 988 restored_full = _restore() 989 self.assertAllEqual(saved_full, restored_full) 990 991 # Saves 10 horizontal parts of a partitioned variable. 992 # Restores into a full variable, non-sliced. 993 saved_full = _save(slices=[10, 1]) 994 restored_full = _restore() 995 self.assertAllEqual(saved_full, restored_full) 996 997 # Restores into a different number/orientation of slices. 998 restored_full = _restore(slices=[2, 1]) # 2 horizon parts. 999 self.assertAllEqual(saved_full, restored_full) 1000 restored_full = _restore(slices=[1, 3]) # 3 vertical parts. 1001 self.assertAllEqual(saved_full, restored_full) 1002 1003 # Restores into a PartitionedVariable 1004 restored_full = _restore( 1005 partitioner=partitioned_variables.fixed_size_partitioner( 1006 num_shards=2)) 1007 self.assertAllEqual(saved_full, restored_full) 1008 1009 # Now, saves a full variable and restores in slices. 1010 saved_full = _save() 1011 restored_full = _restore(slices=[1, 3]) 1012 self.assertAllEqual(saved_full, restored_full) 1013 1014 def testPartitionedVariable(self): 1015 self._testPartitionedVariables(use_resource=False) 1016 1017 def testPartitionedResourceVariable(self): 1018 self._testPartitionedVariables(use_resource=True) 1019 1020 1021@test_util.with_c_api 1022class SaveRestoreShardedTestV2(SaveRestoreShardedTest): 1023 _WRITE_VERSION = saver_pb2.SaverDef.V2 1024 1025 1026@test_util.with_c_api 1027class MaxToKeepTest(test.TestCase): 1028 1029 def _get_test_dir(self, dirname): 1030 test_dir = os.path.join(self.get_temp_dir(), dirname) 1031 gfile.MakeDirs(test_dir) 1032 return test_dir 1033 1034 def assertCheckpointState(self, model_checkpoint_path, 1035 all_model_checkpoint_paths, save_dir): 1036 checkpoint_state = saver_module.get_checkpoint_state(save_dir) 1037 self.assertEqual(checkpoint_state.model_checkpoint_path, 1038 model_checkpoint_path) 1039 self.assertEqual(checkpoint_state.all_model_checkpoint_paths, 1040 all_model_checkpoint_paths) 1041 1042 def testNonSharded(self): 1043 save_dir = self._get_test_dir("max_to_keep_non_sharded") 1044 1045 with self.test_session() as sess: 1046 v = variables.Variable(10.0, name="v") 1047 save = saver_module.Saver({"v": v}, max_to_keep=2) 1048 variables.global_variables_initializer().run() 1049 self.assertEqual([], save.last_checkpoints) 1050 1051 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1052 self.assertEqual([s1], save.last_checkpoints) 1053 self.assertTrue(saver_module.checkpoint_exists(s1)) 1054 self.assertCheckpointState( 1055 model_checkpoint_path=s1, 1056 all_model_checkpoint_paths=[s1], 1057 save_dir=save_dir) 1058 1059 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1060 self.assertEqual([s1, s2], save.last_checkpoints) 1061 self.assertTrue(saver_module.checkpoint_exists(s1)) 1062 self.assertTrue(saver_module.checkpoint_exists(s2)) 1063 self.assertCheckpointState( 1064 model_checkpoint_path=s2, 1065 all_model_checkpoint_paths=[s1, s2], 1066 save_dir=save_dir) 1067 1068 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1069 self.assertEqual([s2, s3], save.last_checkpoints) 1070 self.assertFalse(saver_module.checkpoint_exists(s1)) 1071 self.assertTrue(saver_module.checkpoint_exists(s2)) 1072 self.assertTrue(saver_module.checkpoint_exists(s3)) 1073 self.assertCheckpointState( 1074 model_checkpoint_path=s3, 1075 all_model_checkpoint_paths=[s2, s3], 1076 save_dir=save_dir) 1077 1078 # Create a second helper, identical to the first. 1079 save2 = saver_module.Saver(saver_def=save.as_saver_def()) 1080 save2.set_last_checkpoints(save.last_checkpoints) 1081 1082 # Create a third helper, with the same configuration but no knowledge of 1083 # previous checkpoints. 1084 save3 = saver_module.Saver(saver_def=save.as_saver_def()) 1085 1086 # Exercise the first helper. 1087 1088 # Adding s2 again (old s2 is removed first, then new s2 appended) 1089 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1090 self.assertEqual([s3, s2], save.last_checkpoints) 1091 self.assertFalse(saver_module.checkpoint_exists(s1)) 1092 self.assertFalse( 1093 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1094 self.assertTrue(saver_module.checkpoint_exists(s3)) 1095 self.assertTrue( 1096 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1097 self.assertTrue(saver_module.checkpoint_exists(s2)) 1098 self.assertTrue( 1099 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1100 self.assertCheckpointState( 1101 model_checkpoint_path=s2, 1102 all_model_checkpoint_paths=[s3, s2], 1103 save_dir=save_dir) 1104 1105 # Adding s1 (s3 should now be deleted as oldest in list) 1106 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1107 self.assertEqual([s2, s1], save.last_checkpoints) 1108 self.assertFalse(saver_module.checkpoint_exists(s3)) 1109 self.assertFalse( 1110 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1111 self.assertTrue(saver_module.checkpoint_exists(s2)) 1112 self.assertTrue( 1113 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1114 self.assertTrue(saver_module.checkpoint_exists(s1)) 1115 self.assertTrue( 1116 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1117 self.assertCheckpointState( 1118 model_checkpoint_path=s1, 1119 all_model_checkpoint_paths=[s2, s1], 1120 save_dir=save_dir) 1121 1122 # Exercise the second helper. 1123 1124 # Adding s2 again (old s2 is removed first, then new s2 appended) 1125 s2 = save2.save(sess, os.path.join(save_dir, "s2")) 1126 self.assertEqual([s3, s2], save2.last_checkpoints) 1127 # Created by the first helper. 1128 self.assertTrue(saver_module.checkpoint_exists(s1)) 1129 self.assertTrue( 1130 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1131 # Deleted by the first helper. 1132 self.assertFalse(saver_module.checkpoint_exists(s3)) 1133 self.assertFalse( 1134 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1135 self.assertTrue(saver_module.checkpoint_exists(s2)) 1136 self.assertTrue( 1137 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1138 self.assertCheckpointState( 1139 model_checkpoint_path=s2, 1140 all_model_checkpoint_paths=[s3, s2], 1141 save_dir=save_dir) 1142 1143 # Adding s1 (s3 should now be deleted as oldest in list) 1144 s1 = save2.save(sess, os.path.join(save_dir, "s1")) 1145 self.assertEqual([s2, s1], save2.last_checkpoints) 1146 self.assertFalse(saver_module.checkpoint_exists(s3)) 1147 self.assertFalse( 1148 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1149 self.assertTrue(saver_module.checkpoint_exists(s2)) 1150 self.assertTrue( 1151 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1152 self.assertTrue(saver_module.checkpoint_exists(s1)) 1153 self.assertTrue( 1154 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1155 self.assertCheckpointState( 1156 model_checkpoint_path=s1, 1157 all_model_checkpoint_paths=[s2, s1], 1158 save_dir=save_dir) 1159 1160 # Exercise the third helper. 1161 1162 # Adding s2 again (but helper is unaware of previous s2) 1163 s2 = save3.save(sess, os.path.join(save_dir, "s2")) 1164 self.assertEqual([s2], save3.last_checkpoints) 1165 # Created by the first helper. 1166 self.assertTrue(saver_module.checkpoint_exists(s1)) 1167 self.assertTrue( 1168 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1169 # Deleted by the first helper. 1170 self.assertFalse(saver_module.checkpoint_exists(s3)) 1171 self.assertFalse( 1172 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1173 self.assertTrue(saver_module.checkpoint_exists(s2)) 1174 self.assertTrue( 1175 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1176 # Even though the file for s1 exists, this saver isn't aware of it, which 1177 # is why it doesn't end up in the checkpoint state. 1178 self.assertCheckpointState( 1179 model_checkpoint_path=s2, 1180 all_model_checkpoint_paths=[s2], 1181 save_dir=save_dir) 1182 1183 # Adding s1 (s3 should not be deleted because helper is unaware of it) 1184 s1 = save3.save(sess, os.path.join(save_dir, "s1")) 1185 self.assertEqual([s2, s1], save3.last_checkpoints) 1186 self.assertFalse(saver_module.checkpoint_exists(s3)) 1187 self.assertFalse( 1188 saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) 1189 self.assertTrue(saver_module.checkpoint_exists(s2)) 1190 self.assertTrue( 1191 saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) 1192 self.assertTrue(saver_module.checkpoint_exists(s1)) 1193 self.assertTrue( 1194 saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) 1195 self.assertCheckpointState( 1196 model_checkpoint_path=s1, 1197 all_model_checkpoint_paths=[s2, s1], 1198 save_dir=save_dir) 1199 1200 def testSharded(self): 1201 save_dir = self._get_test_dir("max_to_keep_sharded") 1202 1203 with session.Session( 1204 target="", 1205 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1206 with sess.graph.device("/cpu:0"): 1207 v0 = variables.Variable(111, name="v0") 1208 with sess.graph.device("/cpu:1"): 1209 v1 = variables.Variable(222, name="v1") 1210 save = saver_module.Saver( 1211 { 1212 "v0": v0, 1213 "v1": v1 1214 }, sharded=True, max_to_keep=2) 1215 variables.global_variables_initializer().run() 1216 self.assertEqual([], save.last_checkpoints) 1217 1218 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1219 self.assertEqual([s1], save.last_checkpoints) 1220 if save._write_version is saver_pb2.SaverDef.V1: 1221 self.assertEqual(2, len(gfile.Glob(s1))) 1222 else: 1223 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1224 1225 self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) 1226 1227 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1228 self.assertEqual([s1, s2], save.last_checkpoints) 1229 if save._write_version is saver_pb2.SaverDef.V1: 1230 self.assertEqual(2, len(gfile.Glob(s1))) 1231 else: 1232 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1233 self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) 1234 if save._write_version is saver_pb2.SaverDef.V1: 1235 self.assertEqual(2, len(gfile.Glob(s2))) 1236 else: 1237 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1238 self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) 1239 1240 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1241 self.assertEqual([s2, s3], save.last_checkpoints) 1242 self.assertEqual(0, len(gfile.Glob(s1 + "*"))) 1243 self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) 1244 if save._write_version is saver_pb2.SaverDef.V1: 1245 self.assertEqual(2, len(gfile.Glob(s2))) 1246 else: 1247 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1248 self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) 1249 if save._write_version is saver_pb2.SaverDef.V1: 1250 self.assertEqual(2, len(gfile.Glob(s3))) 1251 else: 1252 self.assertEqual(4, len(gfile.Glob(s3 + "*"))) 1253 self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) 1254 1255 def testNoMaxToKeep(self): 1256 save_dir = self._get_test_dir("no_max_to_keep") 1257 save_dir2 = self._get_test_dir("max_to_keep_0") 1258 1259 with self.test_session() as sess: 1260 v = variables.Variable(10.0, name="v") 1261 variables.global_variables_initializer().run() 1262 1263 # Test max_to_keep being None. 1264 save = saver_module.Saver({"v": v}, max_to_keep=None) 1265 self.assertEqual([], save.last_checkpoints) 1266 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1267 self.assertEqual([], save.last_checkpoints) 1268 self.assertTrue(saver_module.checkpoint_exists(s1)) 1269 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1270 self.assertEqual([], save.last_checkpoints) 1271 self.assertTrue(saver_module.checkpoint_exists(s2)) 1272 1273 # Test max_to_keep being 0. 1274 save2 = saver_module.Saver({"v": v}, max_to_keep=0) 1275 self.assertEqual([], save2.last_checkpoints) 1276 s1 = save2.save(sess, os.path.join(save_dir2, "s1")) 1277 self.assertEqual([], save2.last_checkpoints) 1278 self.assertTrue(saver_module.checkpoint_exists(s1)) 1279 s2 = save2.save(sess, os.path.join(save_dir2, "s2")) 1280 self.assertEqual([], save2.last_checkpoints) 1281 self.assertTrue(saver_module.checkpoint_exists(s2)) 1282 1283 def testNoMetaGraph(self): 1284 save_dir = self._get_test_dir("no_meta_graph") 1285 1286 with self.test_session() as sess: 1287 v = variables.Variable(10.0, name="v") 1288 save = saver_module.Saver({"v": v}) 1289 variables.global_variables_initializer().run() 1290 1291 s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) 1292 self.assertTrue(saver_module.checkpoint_exists(s1)) 1293 self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) 1294 1295 1296@test_util.with_c_api 1297class KeepCheckpointEveryNHoursTest(test.TestCase): 1298 1299 def _get_test_dir(self, dirname): 1300 test_dir = os.path.join(self.get_temp_dir(), dirname) 1301 gfile.MakeDirs(test_dir) 1302 return test_dir 1303 1304 @test.mock.patch.object(saver_module, "time") 1305 def testNonSharded(self, mock_time): 1306 save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") 1307 1308 with self.test_session() as sess: 1309 v = variables.Variable([10.0], name="v") 1310 # Run the initializer NOW to avoid the 0.5s overhead of the first Run() 1311 # call, which throws the test timing off in fastbuild mode. 1312 variables.global_variables_initializer().run() 1313 # Create a saver that will keep the last 2 checkpoints plus one every 0.7 1314 # seconds. 1315 start_time = time.time() 1316 mock_time.time.return_value = start_time 1317 save = saver_module.Saver( 1318 { 1319 "v": v 1320 }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600) 1321 self.assertEqual([], save.last_checkpoints) 1322 1323 # Wait till 1 seconds have elapsed so s1 will be old enough to keep. 1324 # sleep may return early, don't trust it. 1325 mock_time.time.return_value = start_time + 1.0 1326 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1327 self.assertEqual([s1], save.last_checkpoints) 1328 1329 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1330 self.assertEqual([s1, s2], save.last_checkpoints) 1331 1332 # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(), 1333 # would normally delete s1, because max_to_keep is 2. However, s1 is 1334 # older than 0.7s so we must keep it. 1335 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1336 self.assertEqual([s2, s3], save.last_checkpoints) 1337 1338 # s1 should still be here, we are Not checking now to reduce time 1339 # variance in the test. 1340 1341 # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next 1342 # call to Save(), will delete s2, because max_to_keep is 2, and because 1343 # we already kept the old s1. s2 is very close in time to s1 so it gets 1344 # deleted. 1345 s4 = save.save(sess, os.path.join(save_dir, "s4")) 1346 self.assertEqual([s3, s4], save.last_checkpoints) 1347 1348 # Check that s1 is still here, but s2 is gone. 1349 self.assertTrue(saver_module.checkpoint_exists(s1)) 1350 self.assertFalse(saver_module.checkpoint_exists(s2)) 1351 self.assertTrue(saver_module.checkpoint_exists(s3)) 1352 self.assertTrue(saver_module.checkpoint_exists(s4)) 1353 1354 1355@test_util.with_c_api 1356class SaveRestoreWithVariableNameMap(test.TestCase): 1357 1358 def _testNonReshape(self, variable_op): 1359 save_path = os.path.join(self.get_temp_dir(), "non_reshape") 1360 1361 with self.test_session(graph=ops_lib.Graph()) as sess: 1362 # Build a graph with 2 parameter nodes, and Save and 1363 # Restore nodes for them. 1364 v0 = variable_op(10.0, name="v0") 1365 v1 = variable_op(20.0, name="v1") 1366 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1367 self.evaluate(variables.global_variables_initializer()) 1368 1369 # Check that the parameter nodes have been initialized. 1370 self.assertEqual(10.0, self.evaluate(v0)) 1371 self.assertEqual(20.0, self.evaluate(v1)) 1372 1373 # Save the initialized values in the file at "save_path" 1374 # Use a variable name map to set the saved tensor names 1375 val = save.save(sess, save_path) 1376 self.assertTrue(isinstance(val, six.string_types)) 1377 self.assertEqual(save_path, val) 1378 1379 # Verify that the original names are not in the Saved file 1380 save = saver_module.Saver({"v0": v0, "v1": v1}) 1381 with self.assertRaisesOpError("not found in checkpoint"): 1382 save.restore(sess, save_path) 1383 1384 # Verify that the mapped names are present in the Saved file and can be 1385 # Restored using remapped names. 1386 with self.test_session(graph=ops_lib.Graph()) as sess: 1387 v0 = variable_op(-1.0, name="v0") 1388 v1 = variable_op(-1.0, name="v1") 1389 1390 if context.in_graph_mode(): 1391 with self.assertRaisesOpError("uninitialized"): 1392 self.evaluate(v0) 1393 with self.assertRaisesOpError("uninitialized"): 1394 self.evaluate(v1) 1395 1396 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1397 save.restore(sess, save_path) 1398 1399 # Check that the parameter nodes have been restored. 1400 if context.in_graph_mode(): 1401 self.assertEqual(10.0, self.evaluate(v0)) 1402 self.assertEqual(20.0, self.evaluate(v1)) 1403 1404 # Add a prefix to the node names in the current graph and Restore using 1405 # remapped names. 1406 with self.test_session(graph=ops_lib.Graph()) as sess: 1407 v0 = variable_op(-1.0, name="restore_prefix/v0") 1408 v1 = variable_op(-1.0, name="restore_prefix/v1") 1409 1410 if context.in_graph_mode(): 1411 with self.assertRaisesOpError("uninitialized"): 1412 self.evaluate(v0) 1413 with self.assertRaisesOpError("uninitialized"): 1414 self.evaluate(v1) 1415 1416 # Restore the saved values in the parameter nodes. 1417 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1418 save.restore(sess, save_path) 1419 1420 # Check that the parameter nodes have been restored. 1421 self.assertEqual(10.0, self.evaluate(v0)) 1422 self.assertEqual(20.0, self.evaluate(v1)) 1423 1424 @test_util.run_in_graph_and_eager_modes() 1425 def testNonReshapeResourceVariable(self): 1426 self._testNonReshape(resource_variable_ops.ResourceVariable) 1427 1428 def testNonReshapeVariable(self): 1429 self._testNonReshape(variables.Variable) 1430 1431 1432@test_util.with_c_api 1433class LatestCheckpointWithRelativePaths(test.TestCase): 1434 1435 @staticmethod 1436 @contextlib.contextmanager 1437 def tempWorkingDir(temppath): 1438 cwd = os.getcwd() 1439 os.chdir(temppath) 1440 try: 1441 yield 1442 finally: 1443 os.chdir(cwd) 1444 1445 @staticmethod 1446 @contextlib.contextmanager 1447 def tempDir(): 1448 tempdir = tempfile.mkdtemp() 1449 try: 1450 yield tempdir 1451 finally: 1452 shutil.rmtree(tempdir) 1453 1454 def testNameCollision(self): 1455 # Make sure we have a clean directory to work in. 1456 with self.tempDir() as tempdir: 1457 # Jump to that directory until this test is done. 1458 with self.tempWorkingDir(tempdir): 1459 # Save training snapshots to a relative path. 1460 traindir = "train/" 1461 os.mkdir(traindir) 1462 # Collides with the default name of the checkpoint state file. 1463 filepath = os.path.join(traindir, "checkpoint") 1464 1465 with self.test_session() as sess: 1466 unused_a = variables.Variable(0.0) # So that Saver saves something. 1467 variables.global_variables_initializer().run() 1468 1469 # Should fail. 1470 saver = saver_module.Saver(sharded=False) 1471 with self.assertRaisesRegexp(ValueError, "collides with"): 1472 saver.save(sess, filepath) 1473 1474 # Succeeds: the file will be named "checkpoint-<step>". 1475 saver.save(sess, filepath, global_step=1) 1476 self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) 1477 1478 # Succeeds: the file will be named "checkpoint-<i>-of-<n>". 1479 saver = saver_module.Saver(sharded=True) 1480 saver.save(sess, filepath) 1481 self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) 1482 1483 # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>". 1484 saver = saver_module.Saver(sharded=True) 1485 saver.save(sess, filepath, global_step=1) 1486 self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) 1487 1488 def testRelativePath(self): 1489 # Make sure we have a clean directory to work in. 1490 with self.tempDir() as tempdir: 1491 1492 # Jump to that directory until this test is done. 1493 with self.tempWorkingDir(tempdir): 1494 1495 # Save training snapshots to a relative path. 1496 traindir = "train/" 1497 os.mkdir(traindir) 1498 1499 filename = "snapshot" 1500 filepath = os.path.join(traindir, filename) 1501 1502 with self.test_session() as sess: 1503 # Build a simple graph. 1504 v0 = variables.Variable(0.0) 1505 inc = v0.assign_add(1.0) 1506 1507 save = saver_module.Saver({"v0": v0}) 1508 1509 # Record a short training history. 1510 variables.global_variables_initializer().run() 1511 save.save(sess, filepath, global_step=0) 1512 inc.eval() 1513 save.save(sess, filepath, global_step=1) 1514 inc.eval() 1515 save.save(sess, filepath, global_step=2) 1516 1517 with self.test_session() as sess: 1518 # Build a new graph with different initialization. 1519 v0 = variables.Variable(-1.0) 1520 1521 # Create a new saver. 1522 save = saver_module.Saver({"v0": v0}) 1523 variables.global_variables_initializer().run() 1524 1525 # Get the most recent checkpoint name from the training history file. 1526 name = saver_module.latest_checkpoint(traindir) 1527 self.assertIsNotNone(name) 1528 1529 # Restore "v0" from that checkpoint. 1530 save.restore(sess, name) 1531 self.assertEqual(v0.eval(), 2.0) 1532 1533 1534@test_util.with_c_api 1535class CheckpointStateTest(test.TestCase): 1536 1537 def _get_test_dir(self, dirname): 1538 test_dir = os.path.join(self.get_temp_dir(), dirname) 1539 gfile.MakeDirs(test_dir) 1540 return test_dir 1541 1542 def testAbsPath(self): 1543 save_dir = self._get_test_dir("abs_paths") 1544 abs_path = os.path.join(save_dir, "model-0") 1545 ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path) 1546 self.assertEqual(ckpt.model_checkpoint_path, abs_path) 1547 self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) 1548 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) 1549 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) 1550 1551 def testRelPath(self): 1552 train_dir = "train" 1553 model = os.path.join(train_dir, "model-0") 1554 # model_checkpoint_path should have no "train" directory part. 1555 new_rel_path = "model-0" 1556 ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model) 1557 self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) 1558 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) 1559 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) 1560 1561 def testAllModelCheckpointPaths(self): 1562 save_dir = self._get_test_dir("all_models_test") 1563 abs_path = os.path.join(save_dir, "model-0") 1564 for paths in [None, [], ["model-2"]]: 1565 ckpt = saver_module.generate_checkpoint_state_proto( 1566 save_dir, abs_path, all_model_checkpoint_paths=paths) 1567 self.assertEqual(ckpt.model_checkpoint_path, abs_path) 1568 self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) 1569 self.assertEqual( 1570 len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) 1571 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) 1572 1573 def testUpdateCheckpointState(self): 1574 save_dir = self._get_test_dir("update_checkpoint_state") 1575 os.chdir(save_dir) 1576 # Make a temporary train directory. 1577 train_dir = "train" 1578 os.mkdir(train_dir) 1579 abs_path = os.path.join(save_dir, "model-0") 1580 rel_path = os.path.join("train", "model-2") 1581 saver_module.update_checkpoint_state( 1582 train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) 1583 ckpt = saver_module.get_checkpoint_state(train_dir) 1584 self.assertEqual(ckpt.model_checkpoint_path, rel_path) 1585 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 1586 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) 1587 self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) 1588 1589 def testUpdateCheckpointStateSaveRelativePaths(self): 1590 save_dir = self._get_test_dir("update_checkpoint_state") 1591 os.chdir(save_dir) 1592 abs_path2 = os.path.join(save_dir, "model-2") 1593 rel_path2 = "model-2" 1594 abs_path0 = os.path.join(save_dir, "model-0") 1595 rel_path0 = "model-0" 1596 saver_module._update_checkpoint_state( # pylint: disable=protected-access 1597 save_dir=save_dir, 1598 model_checkpoint_path=abs_path2, 1599 all_model_checkpoint_paths=[rel_path0, abs_path2], 1600 save_relative_paths=True) 1601 1602 # File should contain relative paths. 1603 file_content = file_io.read_file_to_string( 1604 os.path.join(save_dir, "checkpoint")) 1605 ckpt = CheckpointState() 1606 text_format.Merge(file_content, ckpt) 1607 self.assertEqual(ckpt.model_checkpoint_path, rel_path2) 1608 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 1609 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) 1610 self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) 1611 1612 # get_checkpoint_state should return absolute paths. 1613 ckpt = saver_module.get_checkpoint_state(save_dir) 1614 self.assertEqual(ckpt.model_checkpoint_path, abs_path2) 1615 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 1616 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) 1617 self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) 1618 1619 def testCheckPointStateFailsWhenIncomplete(self): 1620 save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") 1621 os.chdir(save_dir) 1622 ckpt_path = os.path.join(save_dir, "checkpoint") 1623 ckpt_file = open(ckpt_path, "w") 1624 ckpt_file.write("") 1625 ckpt_file.close() 1626 with self.assertRaises(ValueError): 1627 saver_module.get_checkpoint_state(save_dir) 1628 1629 def testCheckPointCompletesRelativePaths(self): 1630 save_dir = self._get_test_dir("checkpoint_completes_relative_paths") 1631 os.chdir(save_dir) 1632 ckpt_path = os.path.join(save_dir, "checkpoint") 1633 ckpt_file = open(ckpt_path, "w") 1634 ckpt_file.write(""" 1635 model_checkpoint_path: "./model.ckpt-687529" 1636 all_model_checkpoint_paths: "./model.ckpt-687500" 1637 all_model_checkpoint_paths: "./model.ckpt-687529" 1638 """) 1639 ckpt_file.close() 1640 ckpt = saver_module.get_checkpoint_state(save_dir) 1641 self.assertEqual(ckpt.model_checkpoint_path, 1642 os.path.join(save_dir, "./model.ckpt-687529")) 1643 self.assertEqual(ckpt.all_model_checkpoint_paths[0], 1644 os.path.join(save_dir, "./model.ckpt-687500")) 1645 self.assertEqual(ckpt.all_model_checkpoint_paths[1], 1646 os.path.join(save_dir, "./model.ckpt-687529")) 1647 1648 1649@test_util.with_c_api 1650class MetaGraphTest(test.TestCase): 1651 1652 def _get_test_dir(self, dirname): 1653 test_dir = os.path.join(self.get_temp_dir(), dirname) 1654 gfile.MakeDirs(test_dir) 1655 return test_dir 1656 1657 def testAddCollectionDef(self): 1658 test_dir = self._get_test_dir("good_collection") 1659 filename = os.path.join(test_dir, "metafile") 1660 with self.test_session(): 1661 # Creates a graph. 1662 v0 = variables.Variable(1.0, name="v0") 1663 control_flow_ops.cond( 1664 math_ops.less(v0, 10), lambda: math_ops.add(v0, 1), 1665 lambda: math_ops.subtract(v0, 1)) 1666 control_flow_ops.while_loop(lambda i: math_ops.less(i, 10), 1667 lambda i: math_ops.add(i, 1), [v0]) 1668 var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64)) 1669 count_up_to = var.count_up_to(3) 1670 input_queue = data_flow_ops.FIFOQueue( 1671 30, dtypes.float32, shared_name="collection_queue") 1672 qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to]) 1673 variables.global_variables_initializer() 1674 # Creates a saver. 1675 save = saver_module.Saver({"v0": v0}) 1676 # Adds a set of collections. 1677 ops_lib.add_to_collection("int_collection", 3) 1678 ops_lib.add_to_collection("float_collection", 3.5) 1679 ops_lib.add_to_collection("string_collection", "hello") 1680 ops_lib.add_to_collection("variable_collection", v0) 1681 # Add QueueRunners. 1682 queue_runner_impl.add_queue_runner(qr) 1683 # Adds user_defined proto in three formats: string, bytes and Any. 1684 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 1685 ops_lib.add_to_collection("user_defined_string_collection", 1686 str(queue_runner)) 1687 ops_lib.add_to_collection("user_defined_bytes_collection", 1688 queue_runner.SerializeToString()) 1689 any_buf = Any() 1690 any_buf.Pack(queue_runner) 1691 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 1692 1693 # Generates MetaGraphDef. 1694 meta_graph_def = save.export_meta_graph(filename) 1695 self.assertTrue(meta_graph_def.HasField("saver_def")) 1696 self.assertTrue(meta_graph_def.HasField("graph_def")) 1697 self.assertTrue(meta_graph_def.HasField("meta_info_def")) 1698 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "") 1699 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version, 1700 "") 1701 collection_def = meta_graph_def.collection_def 1702 self.assertEqual(len(collection_def), 12) 1703 1704 with ops_lib.Graph().as_default(): 1705 # Restores from MetaGraphDef. 1706 new_saver = saver_module.import_meta_graph(filename) 1707 # Generates a new MetaGraphDef. 1708 new_meta_graph_def = new_saver.export_meta_graph() 1709 # It should be the same as the original. 1710 1711 test_util.assert_meta_graph_protos_equal( 1712 self, meta_graph_def, new_meta_graph_def) 1713 1714 def testAddCollectionDefFails(self): 1715 with self.test_session(): 1716 # Creates a graph. 1717 v0 = variables.Variable(10.0, name="v0") 1718 # Creates a saver. 1719 save = saver_module.Saver({"v0": v0}) 1720 # Generates MetaGraphDef. 1721 meta_graph_def = meta_graph_pb2.MetaGraphDef() 1722 1723 # Verifies that collection with unsupported key will not be added. 1724 ops_lib.add_to_collection(save, 3) 1725 save._add_collection_def(meta_graph_def, save) 1726 self.assertEqual(len(meta_graph_def.collection_def), 0) 1727 1728 # Verifies that collection where item type does not match expected 1729 # type will not be added. 1730 ops_lib.add_to_collection("int_collection", 3) 1731 ops_lib.add_to_collection("int_collection", 3.5) 1732 save._add_collection_def(meta_graph_def, "int_collection") 1733 self.assertEqual(len(meta_graph_def.collection_def), 0) 1734 1735 def _testMultiSaverCollectionSave(self, test_dir): 1736 filename = os.path.join(test_dir, "metafile") 1737 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1738 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1739 with self.test_session(graph=ops_lib.Graph()) as sess: 1740 # Creates a graph. 1741 v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1742 v1 = variables.Variable(11.0, name="v1") 1743 # Creates 2 savers. 1744 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1745 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1746 ops_lib.add_to_collection("savers", saver0) 1747 ops_lib.add_to_collection("savers", saver1) 1748 variables.global_variables_initializer().run() 1749 # Saves to different checkpoints. 1750 saver0.save(sess, saver0_ckpt) 1751 saver1.save(sess, saver1_ckpt) 1752 # Generates MetaGraphDef. 1753 meta_graph_def = saver_module.export_meta_graph(filename) 1754 meta_graph_def0 = saver0.export_meta_graph() 1755 meta_graph_def1 = saver1.export_meta_graph() 1756 1757 # Verifies that there is no saver_def in meta_graph_def. 1758 self.assertFalse(meta_graph_def.HasField("saver_def")) 1759 # Verifies that there is saver_def in meta_graph_def0 and 1. 1760 self.assertTrue(meta_graph_def0.HasField("saver_def")) 1761 self.assertTrue(meta_graph_def1.HasField("saver_def")) 1762 1763 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 1764 collection_def = meta_graph_def.collection_def["savers"] 1765 kind = collection_def.WhichOneof("kind") 1766 self.assertEqual(kind, "bytes_list") 1767 # Verifies that there are 2 entries in SAVERS collection. 1768 savers = getattr(collection_def, kind) 1769 self.assertEqual(2, len(savers.value)) 1770 1771 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0. 1772 collection_def = meta_graph_def0.collection_def["savers"] 1773 kind = collection_def.WhichOneof("kind") 1774 self.assertEqual(kind, "bytes_list") 1775 # Verifies that there are 2 entries in SAVERS collection. 1776 savers = getattr(collection_def, kind) 1777 self.assertEqual(2, len(savers.value)) 1778 1779 def _testMultiSaverCollectionRestore(self, test_dir): 1780 filename = os.path.join(test_dir, "metafile") 1781 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1782 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1783 with self.test_session(graph=ops_lib.Graph()) as sess: 1784 # Imports from meta_graph. 1785 saver_module.import_meta_graph(filename) 1786 # Retrieves SAVERS collection. Verifies there are 2 entries. 1787 savers = ops_lib.get_collection("savers") 1788 self.assertEqual(2, len(savers)) 1789 # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1. 1790 new_saver0 = savers[0] 1791 new_saver0.restore(sess, saver0_ckpt) 1792 v0 = sess.graph.get_tensor_by_name("v0:0") 1793 v1 = sess.graph.get_tensor_by_name("v1:0") 1794 self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], v0.eval()) 1795 self.assertEqual([3, 2], v0.get_shape()) 1796 self.assertEqual([], v1.get_shape()) 1797 with self.assertRaisesWithPredicateMatch( 1798 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 1799 sess.run(v1) 1800 # Retrieves saver1. Verifies that new_saver1 can restore v1. 1801 new_saver1 = savers[1] 1802 new_saver1.restore(sess, saver1_ckpt) 1803 v1 = sess.graph.get_tensor_by_name("v1:0") 1804 self.assertEqual(11.0, v1.eval()) 1805 1806 def testMultiSaverCollection(self): 1807 test_dir = self._get_test_dir("saver_collection") 1808 self._testMultiSaverCollectionSave(test_dir) 1809 self._testMultiSaverCollectionRestore(test_dir) 1810 1811 def testClearExtraneousSavers(self): 1812 test_dir = self._get_test_dir("clear_extraneous_savers") 1813 filename = os.path.join(test_dir, "metafile") 1814 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1815 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1816 with self.test_session(graph=ops_lib.Graph()) as sess: 1817 # Creates a graph. 1818 v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1819 v1 = variables.Variable(11.0, name="v1") 1820 1821 # Creates 2 savers. 1822 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1823 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1824 ops_lib.add_to_collection("savers", saver0) 1825 ops_lib.add_to_collection("savers", saver1) 1826 variables.global_variables_initializer().run() 1827 1828 # Saves to different checkpoints. 1829 saver0.save(sess, saver0_ckpt) 1830 saver1.save(sess, saver1_ckpt) 1831 1832 # Generates MetaGraphDef. 1833 meta_graph_def = saver_module.export_meta_graph(filename) 1834 meta_graph_def0 = saver0.export_meta_graph() 1835 meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True) 1836 1837 # Verifies that there is no saver_def in meta_graph_def. 1838 self.assertFalse(meta_graph_def.HasField("saver_def")) 1839 # Verifies that there is saver_def in meta_graph_def0 and 1. 1840 self.assertTrue(meta_graph_def0.HasField("saver_def")) 1841 self.assertTrue(meta_graph_def1.HasField("saver_def")) 1842 1843 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 1844 collection_def = meta_graph_def.collection_def["savers"] 1845 kind = collection_def.WhichOneof("kind") 1846 self.assertEqual(kind, "bytes_list") 1847 1848 # Verifies that there are 2 entries in SAVERS collection. 1849 savers = getattr(collection_def, kind) 1850 self.assertEqual(2, len(savers.value)) 1851 1852 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1. 1853 collection_def = meta_graph_def1.collection_def["savers"] 1854 kind = collection_def.WhichOneof("kind") 1855 self.assertEqual(kind, "bytes_list") 1856 1857 # Verifies that there is 1 entry in SAVERS collection. 1858 savers = getattr(collection_def, kind) 1859 self.assertEqual(1, len(savers.value)) 1860 1861 # Verifies that saver0 graph nodes are omitted from the saver1 export 1862 self.assertEqual(29, len(meta_graph_def0.graph_def.node)) 1863 self.assertEqual(19, len(meta_graph_def1.graph_def.node)) 1864 1865 def testBinaryAndTextFormat(self): 1866 test_dir = self._get_test_dir("binary_and_text") 1867 filename = os.path.join(test_dir, "metafile") 1868 with self.test_session(graph=ops_lib.Graph()): 1869 # Creates a graph. 1870 variables.Variable(10.0, name="v0") 1871 # Exports the graph as binary format. 1872 saver_module.export_meta_graph(filename, as_text=False) 1873 with self.test_session(graph=ops_lib.Graph()): 1874 # Imports the binary format graph. 1875 saver = saver_module.import_meta_graph(filename) 1876 self.assertIsNotNone(saver) 1877 # Exports the graph as text format. 1878 saver.export_meta_graph(filename, as_text=True) 1879 with self.test_session(graph=ops_lib.Graph()): 1880 # Imports the text format graph. 1881 saver_module.import_meta_graph(filename) 1882 # Writes wrong contents to the file. 1883 graph_io.write_graph(saver.as_saver_def(), 1884 os.path.dirname(filename), 1885 os.path.basename(filename)) 1886 with self.test_session(graph=ops_lib.Graph()): 1887 # Import should fail. 1888 with self.assertRaisesWithPredicateMatch(IOError, 1889 lambda e: "Cannot parse file"): 1890 saver_module.import_meta_graph(filename) 1891 # Deletes the file 1892 gfile.Remove(filename) 1893 with self.assertRaisesWithPredicateMatch(IOError, 1894 lambda e: "does not exist"): 1895 saver_module.import_meta_graph(filename) 1896 1897 def testSliceVariable(self): 1898 test_dir = self._get_test_dir("slice_saver") 1899 filename = os.path.join(test_dir, "metafile") 1900 with self.test_session(): 1901 v1 = variables.Variable([20.0], name="v1") 1902 v2 = variables.Variable([20.0], name="v2") 1903 v2._set_save_slice_info( 1904 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 1905 1906 # The names are different and will work. 1907 slice_saver = saver_module.Saver({"first": v1, "second": v2}) 1908 variables.global_variables_initializer().run() 1909 # Exports to meta_graph 1910 meta_graph_def = slice_saver.export_meta_graph(filename) 1911 1912 with ops_lib.Graph().as_default(): 1913 # Restores from MetaGraphDef. 1914 new_saver = saver_module.import_meta_graph(filename) 1915 self.assertIsNotNone(new_saver) 1916 # Generates a new MetaGraphDef. 1917 new_meta_graph_def = new_saver.export_meta_graph() 1918 # It should be the same as the original. 1919 test_util.assert_meta_graph_protos_equal(self, meta_graph_def, 1920 new_meta_graph_def) 1921 1922 def _testGraphExtensionSave(self, test_dir): 1923 filename = os.path.join(test_dir, "metafile") 1924 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1925 # Creates an inference graph. 1926 # Hidden 1 1927 images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28]) 1928 with ops_lib.name_scope("hidden1"): 1929 weights = variables.Variable( 1930 random_ops.truncated_normal( 1931 [28, 128], stddev=1.0 / math.sqrt(float(28))), 1932 name="weights") 1933 # The use of control_flow_ops.cond here is purely for adding test coverage 1934 # the save and restore of control flow context (which doesn't make any 1935 # sense here from a machine learning perspective). The typical biases is 1936 # a simple Variable without the conditions. 1937 biases = variables.Variable( 1938 control_flow_ops.cond( 1939 math_ops.less(random.random(), 0.5), 1940 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 1941 name="biases") 1942 hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases) 1943 # Hidden 2 1944 with ops_lib.name_scope("hidden2"): 1945 weights = variables.Variable( 1946 random_ops.truncated_normal( 1947 [128, 32], stddev=1.0 / math.sqrt(float(128))), 1948 name="weights") 1949 1950 # The use of control_flow_ops.while_loop here is purely for adding test 1951 # coverage the save and restore of control flow context (which doesn't 1952 # make any sense here from a machine learning perspective). The typical 1953 # biases is a simple Variable without the conditions. 1954 def loop_cond(it, _): 1955 return it < 2 1956 1957 def loop_body(it, biases): 1958 biases += constant_op.constant(0.1, shape=[32]) 1959 return it + 1, biases 1960 1961 _, biases = control_flow_ops.while_loop( 1962 loop_cond, loop_body, 1963 [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))]) 1964 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 1965 # Linear 1966 with ops_lib.name_scope("softmax_linear"): 1967 weights = variables.Variable( 1968 random_ops.truncated_normal( 1969 [32, 10], stddev=1.0 / math.sqrt(float(32))), 1970 name="weights") 1971 biases = variables.Variable(array_ops.zeros([10]), name="biases") 1972 logits = math_ops.matmul(hidden2, weights) + biases 1973 ops_lib.add_to_collection("logits", logits) 1974 init_all_op = variables.global_variables_initializer() 1975 1976 with self.test_session() as sess: 1977 # Initializes all the variables. 1978 sess.run(init_all_op) 1979 # Runs to logit. 1980 sess.run(logits) 1981 # Creates a saver. 1982 saver0 = saver_module.Saver() 1983 saver0.save(sess, saver0_ckpt) 1984 # Generates MetaGraphDef. 1985 saver0.export_meta_graph(filename) 1986 1987 def _testGraphExtensionRestore(self, test_dir): 1988 filename = os.path.join(test_dir, "metafile") 1989 train_filename = os.path.join(test_dir, "train_metafile") 1990 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1991 with self.test_session(graph=ops_lib.Graph()) as sess: 1992 # Restores from MetaGraphDef. 1993 new_saver = saver_module.import_meta_graph(filename) 1994 # Generates a new MetaGraphDef. 1995 new_saver.export_meta_graph() 1996 # Restores from checkpoint. 1997 new_saver.restore(sess, saver0_ckpt) 1998 # Adds loss and train. 1999 labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels") 2000 batch_size = array_ops.size(labels) 2001 labels = array_ops.expand_dims(labels, 1) 2002 indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1) 2003 concated = array_ops.concat([indices, labels], 1) 2004 onehot_labels = sparse_ops.sparse_to_dense( 2005 concated, array_ops.stack([batch_size, 10]), 1.0, 0.0) 2006 logits = ops_lib.get_collection("logits")[0] 2007 cross_entropy = nn_ops.softmax_cross_entropy_with_logits( 2008 labels=onehot_labels, logits=logits, name="xentropy") 2009 loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean") 2010 2011 summary.scalar("loss", loss) 2012 # Creates the gradient descent optimizer with the given learning rate. 2013 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 2014 2015 # Runs train_op. 2016 train_op = optimizer.minimize(loss) 2017 ops_lib.add_to_collection("train_op", train_op) 2018 2019 # Runs train_op. 2020 sess.run(train_op) 2021 2022 # Generates MetaGraphDef. 2023 saver_module.export_meta_graph(train_filename) 2024 2025 def _testRestoreFromTrainGraphWithControlContext(self, test_dir): 2026 train_filename = os.path.join(test_dir, "train_metafile") 2027 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2028 with self.test_session(graph=ops_lib.Graph()) as sess: 2029 # Restores from MetaGraphDef. 2030 new_saver = saver_module.import_meta_graph(train_filename) 2031 # Restores from checkpoint. 2032 new_saver.restore(sess, saver0_ckpt) 2033 train_op = ops_lib.get_collection("train_op")[0] 2034 sess.run(train_op) 2035 2036 def testGraphExtension(self): 2037 test_dir = self._get_test_dir("graph_extension") 2038 self._testGraphExtensionSave(test_dir) 2039 self._testGraphExtensionRestore(test_dir) 2040 self._testRestoreFromTrainGraphWithControlContext(test_dir) 2041 2042 def testStrippedOpListDef(self): 2043 with self.test_session(): 2044 # Creates a graph. 2045 v0 = variables.Variable(0.0) 2046 var = variables.Variable(10.0) 2047 math_ops.add(v0, var) 2048 2049 @function.Defun(dtypes.float32) 2050 def minus_one(x): 2051 return x - 1 2052 2053 minus_one(array_ops.identity(v0)) 2054 save = saver_module.Saver({"v0": v0}) 2055 variables.global_variables_initializer() 2056 2057 # Generates MetaGraphDef. 2058 meta_graph_def = save.export_meta_graph() 2059 ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op] 2060 if save._write_version is saver_pb2.SaverDef.V1: 2061 self.assertEqual(ops, [ 2062 "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", 2063 "SaveSlices", "Sub", "VariableV2" 2064 ]) 2065 else: 2066 self.assertEqual(ops, [ 2067 "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2", 2068 "Sub", "VariableV2" 2069 ]) 2070 2071 # Test calling stripped_op_list_for_graph directly 2072 op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def) 2073 self.assertEqual(ops, [o.name for o in op_list.op]) 2074 for o in op_list.op: 2075 self.assertEqual(o.summary, "") 2076 self.assertEqual(o.description, "") 2077 2078 def testStripDefaultValuedAttrs(self): 2079 """Verifies that default valued attrs are stripped, unless disabled.""" 2080 2081 # With strip_default_attrs enabled, attributes "T" (float32) and "Tout" 2082 # (complex64) in the "Complex" op must be removed. 2083 with self.test_session(): 2084 real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") 2085 imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") 2086 math_ops.complex(real_num, imag_num, name="complex") 2087 2088 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2089 variables.global_variables_initializer() 2090 2091 meta_graph_def = save.export_meta_graph(strip_default_attrs=True) 2092 node_def = test_util.get_node_def_from_graph("complex", 2093 meta_graph_def.graph_def) 2094 self.assertNotIn("T", node_def.attr) 2095 self.assertNotIn("Tout", node_def.attr) 2096 2097 # With strip_default_attrs disabled, attributes "T" (float32) and "Tout" 2098 # (complex64) in the "Complex" op must *not* be removed, even if they map 2099 # to their defaults. 2100 with self.test_session(graph=ops_lib.Graph()): 2101 real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") 2102 imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") 2103 math_ops.complex(real_num, imag_num, name="complex") 2104 2105 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2106 variables.global_variables_initializer() 2107 2108 meta_graph_def = save.export_meta_graph(strip_default_attrs=False) 2109 node_def = test_util.get_node_def_from_graph("complex", 2110 meta_graph_def.graph_def) 2111 self.assertIn("T", node_def.attr) 2112 self.assertIn("Tout", node_def.attr) 2113 2114 def testImportIntoNamescope(self): 2115 # Test that we can import a meta graph into a namescope. 2116 test_dir = self._get_test_dir("import_into_namescope") 2117 filename = os.path.join(test_dir, "ckpt") 2118 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2119 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2120 with session.Session() as sess: 2121 weights = variables.Variable( 2122 random_ops.random_uniform([784, 10]), name="weights") 2123 bias = variables.Variable(array_ops.zeros([10]), name="bias") 2124 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits") 2125 nn_ops.softmax(logit, name="prediction") 2126 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2127 logits=logit, name="cost") 2128 adam.AdamOptimizer().minimize(cost, name="optimize") 2129 saver = saver_module.Saver() 2130 sess.run(variables.global_variables_initializer()) 2131 saver.save(sess, filename) 2132 2133 graph = ops_lib.Graph() 2134 with session.Session(graph=graph) as sess: 2135 new_saver = saver_module.import_meta_graph( 2136 filename + ".meta", graph=graph, import_scope="new_model") 2137 new_saver.restore(sess, filename) 2138 sess.run(["new_model/optimize"], { 2139 "new_model/image:0": np.random.random([1, 784]), 2140 "new_model/label:0": np.random.randint( 2141 10, size=[1, 10]) 2142 }) 2143 2144 def testClearDevicesOnImport(self): 2145 # Test that we import a graph without its devices and run successfully. 2146 with ops_lib.Graph().as_default(): 2147 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2148 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2149 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2150 weights = variables.Variable( 2151 random_ops.random_uniform([784, 10]), name="weights") 2152 bias = variables.Variable(array_ops.zeros([10]), name="bias") 2153 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2154 nn_ops.softmax(logit, name="prediction") 2155 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2156 logits=logit) 2157 adam.AdamOptimizer().minimize(cost, name="optimize") 2158 meta_graph_def = saver_module.export_meta_graph() 2159 2160 with session.Session(graph=ops_lib.Graph()) as sess: 2161 saver_module.import_meta_graph( 2162 meta_graph_def, clear_devices=False, import_scope="new_model") 2163 # Device refers to GPU, which is not available here. 2164 with self.assertRaises(errors_impl.InvalidArgumentError): 2165 sess.run(variables.global_variables_initializer()) 2166 2167 with session.Session(graph=ops_lib.Graph()) as sess: 2168 saver_module.import_meta_graph( 2169 meta_graph_def, clear_devices=True, import_scope="new_model") 2170 sess.run(variables.global_variables_initializer()) 2171 sess.run(["new_model/optimize"], { 2172 "new_model/image:0": np.random.random([1, 784]), 2173 "new_model/label:0": np.random.randint( 2174 10, size=[1, 10]) 2175 }) 2176 2177 def testClearDevicesOnExport(self): 2178 # Test that we export a graph without its devices and run successfully. 2179 with ops_lib.Graph().as_default(): 2180 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2181 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2182 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2183 weights = variables.Variable( 2184 random_ops.random_uniform([784, 10]), name="weights") 2185 bias = variables.Variable(array_ops.zeros([10]), name="bias") 2186 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2187 nn_ops.softmax(logit, name="prediction") 2188 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2189 logits=logit) 2190 adam.AdamOptimizer().minimize(cost, name="optimize") 2191 meta_graph_def = saver_module.export_meta_graph(clear_devices=True) 2192 graph_io.write_graph(meta_graph_def, self.get_temp_dir(), 2193 "meta_graph.pbtxt") 2194 2195 with session.Session(graph=ops_lib.Graph()) as sess: 2196 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2197 sess.run(variables.global_variables_initializer()) 2198 sess.run(["new_model/optimize"], { 2199 "new_model/image:0": np.random.random([1, 784]), 2200 "new_model/label:0": np.random.randint( 2201 10, size=[1, 10]) 2202 }) 2203 2204 def testPreserveDatasetAndFunctions(self): 2205 with ops_lib.Graph().as_default() as g: 2206 dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x) 2207 iterator = dataset.make_one_shot_iterator() 2208 next_element = iterator.get_next() 2209 _ = array_ops.identity(next_element, name="output") 2210 2211 # Generate three MetaGraphDef protos using different code paths. 2212 meta_graph_def_simple = saver_module.export_meta_graph() 2213 meta_graph_def_devices_cleared = saver_module.export_meta_graph( 2214 clear_devices=True) 2215 meta_graph_def_from_graph_def = saver_module.export_meta_graph( 2216 clear_devices=True, graph_def=g.as_graph_def()) 2217 2218 for meta_graph_def in [meta_graph_def_simple, 2219 meta_graph_def_devices_cleared, 2220 meta_graph_def_from_graph_def]: 2221 with session.Session(graph=ops_lib.Graph()) as sess: 2222 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2223 sess.run(variables.global_variables_initializer()) 2224 for i in range(10): 2225 self.assertEqual(i * i, sess.run("new_model/output:0")) 2226 with self.assertRaises(errors.OutOfRangeError): 2227 sess.run("new_model/output:0") 2228 2229 2230@test_util.with_c_api 2231class CheckpointReaderTest(test.TestCase): 2232 2233 _WRITE_VERSION = saver_pb2.SaverDef.V1 2234 2235 def testDebugString(self): 2236 # Builds a graph. 2237 v0 = variables.Variable( 2238 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2239 v1 = variables.Variable( 2240 [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1") 2241 init_all_op = variables.global_variables_initializer() 2242 save = saver_module.Saver( 2243 { 2244 "v0": v0, 2245 "v1": v1 2246 }, write_version=self._WRITE_VERSION) 2247 save_path = os.path.join(self.get_temp_dir(), 2248 "ckpt_for_debug_string" + str(self._WRITE_VERSION)) 2249 with self.test_session() as sess: 2250 sess.run(init_all_op) 2251 # Saves a checkpoint. 2252 save.save(sess, save_path) 2253 2254 # Creates a reader. 2255 reader = pywrap_tensorflow.NewCheckpointReader(save_path) 2256 # Verifies that the tensors exist. 2257 self.assertTrue(reader.has_tensor("v0")) 2258 self.assertTrue(reader.has_tensor("v1")) 2259 debug_string = reader.debug_string() 2260 # Verifies that debug string contains the right strings. 2261 self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string) 2262 self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string) 2263 # Verifies get_variable_to_shape_map() returns the correct information. 2264 var_map = reader.get_variable_to_shape_map() 2265 self.assertEqual([2, 3], var_map["v0"]) 2266 self.assertEqual([3, 2, 1], var_map["v1"]) 2267 # Verifies get_tensor() returns the tensor value. 2268 v0_tensor = reader.get_tensor("v0") 2269 v1_tensor = reader.get_tensor("v1") 2270 self.assertAllEqual(v0.eval(), v0_tensor) 2271 self.assertAllEqual(v1.eval(), v1_tensor) 2272 # Verifies get_tensor() fails for non-existent tensors. 2273 with self.assertRaisesRegexp(errors.NotFoundError, 2274 "v3 not found in checkpoint"): 2275 reader.get_tensor("v3") 2276 2277 def testNonexistentPath(self): 2278 with self.assertRaisesRegexp(errors.NotFoundError, 2279 "Unsuccessful TensorSliceReader"): 2280 pywrap_tensorflow.NewCheckpointReader("non-existent") 2281 2282 2283@test_util.with_c_api 2284class CheckpointReaderForV2Test(CheckpointReaderTest): 2285 _WRITE_VERSION = saver_pb2.SaverDef.V2 2286 2287 2288@test_util.with_c_api 2289class WriteGraphTest(test.TestCase): 2290 2291 def _get_test_dir(self, dirname): 2292 test_dir = os.path.join(self.get_temp_dir(), dirname) 2293 gfile.MakeDirs(test_dir) 2294 return test_dir 2295 2296 def testWriteGraph(self): 2297 test_dir = self._get_test_dir("write_graph_dir") 2298 variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2299 path = graph_io.write_graph(ops_lib.get_default_graph(), 2300 os.path.join(test_dir, "l1"), "graph.pbtxt") 2301 truth = os.path.join(test_dir, "l1", "graph.pbtxt") 2302 self.assertEqual(path, truth) 2303 self.assertTrue(os.path.exists(path)) 2304 2305 def testRecursiveCreate(self): 2306 test_dir = self._get_test_dir("deep_dir") 2307 variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2308 path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), 2309 os.path.join(test_dir, "l1", "l2", "l3"), 2310 "graph.pbtxt") 2311 truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt") 2312 self.assertEqual(path, truth) 2313 self.assertTrue(os.path.exists(path)) 2314 2315 2316@test_util.with_c_api 2317class SaverUtilsTest(test.TestCase): 2318 2319 def setUp(self): 2320 self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") 2321 gfile.MakeDirs(self._base_dir) 2322 2323 def tearDown(self): 2324 gfile.DeleteRecursively(self._base_dir) 2325 2326 def testCheckpointExists(self): 2327 for sharded in (False, True): 2328 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): 2329 with self.test_session(graph=ops_lib.Graph()) as sess: 2330 unused_v = variables.Variable(1.0, name="v") 2331 variables.global_variables_initializer().run() 2332 saver = saver_module.Saver(sharded=sharded, write_version=version) 2333 2334 path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) 2335 self.assertFalse( 2336 saver_module.checkpoint_exists(path)) # Not saved yet. 2337 2338 ckpt_prefix = saver.save(sess, path) 2339 self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) 2340 2341 ckpt_prefix = saver_module.latest_checkpoint(self._base_dir) 2342 self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) 2343 2344 def testGetCheckpointMtimes(self): 2345 prefixes = [] 2346 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): 2347 with self.test_session(graph=ops_lib.Graph()) as sess: 2348 unused_v = variables.Variable(1.0, name="v") 2349 variables.global_variables_initializer().run() 2350 saver = saver_module.Saver(write_version=version) 2351 prefixes.append( 2352 saver.save(sess, os.path.join(self._base_dir, str(version)))) 2353 2354 mtimes = saver_module.get_checkpoint_mtimes(prefixes) 2355 self.assertEqual(2, len(mtimes)) 2356 self.assertTrue(mtimes[1] >= mtimes[0]) 2357 2358 2359@test_util.with_c_api 2360class ScopedGraphTest(test.TestCase): 2361 2362 def _get_test_dir(self, dirname): 2363 test_dir = os.path.join(self.get_temp_dir(), dirname) 2364 gfile.MakeDirs(test_dir) 2365 return test_dir 2366 2367 def _testScopedSave(self, test_dir, exported_filename, ckpt_filename): 2368 graph = ops_lib.Graph() 2369 with graph.as_default(): 2370 # Creates an inference graph. 2371 # Hidden 1 2372 images = constant_op.constant( 2373 1.2, dtypes.float32, shape=[100, 28], name="images") 2374 with ops_lib.name_scope("hidden1"): 2375 weights1 = variables.Variable( 2376 random_ops.truncated_normal( 2377 [28, 128], stddev=1.0 / math.sqrt(float(28))), 2378 name="weights") 2379 # The use of control_flow_ops.cond here is purely for adding test 2380 # coverage the save and restore of control flow context (which doesn't 2381 # make any sense here from a machine learning perspective). The typical 2382 # biases is a simple Variable without the conditions. 2383 biases1 = variables.Variable( 2384 control_flow_ops.cond( 2385 math_ops.less(random.random(), 0.5), 2386 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 2387 name="biases") 2388 hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1) 2389 2390 # Hidden 2 2391 with ops_lib.name_scope("hidden2"): 2392 weights2 = variables.Variable( 2393 random_ops.truncated_normal( 2394 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2395 name="weights") 2396 2397 # The use of control_flow_ops.while_loop here is purely for adding test 2398 # coverage the save and restore of control flow context (which doesn't 2399 # make any sense here from a machine learning perspective). The typical 2400 # biases is a simple Variable without the conditions. 2401 def loop_cond(it, _): 2402 return it < 2 2403 2404 def loop_body(it, biases2): 2405 biases2 += constant_op.constant(0.1, shape=[32]) 2406 return it + 1, biases2 2407 2408 _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [ 2409 constant_op.constant(0), variables.Variable(array_ops.zeros([32])) 2410 ]) 2411 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2) 2412 # Linear 2413 with ops_lib.name_scope("softmax_linear"): 2414 weights3 = variables.Variable( 2415 random_ops.truncated_normal( 2416 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2417 name="weights") 2418 biases3 = variables.Variable(array_ops.zeros([10]), name="biases") 2419 logits = math_ops.matmul(hidden2, weights3) + biases3 2420 ops_lib.add_to_collection("logits", logits) 2421 2422 # Adds user_defined proto in three formats: string, bytes and Any. 2423 # Any proto should just pass through. 2424 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 2425 ops_lib.add_to_collection("user_defined_string_collection", 2426 str(queue_runner)) 2427 ops_lib.add_to_collection("user_defined_bytes_collection", 2428 queue_runner.SerializeToString()) 2429 any_buf = Any() 2430 any_buf.Pack(queue_runner) 2431 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 2432 2433 _, var_list = meta_graph.export_scoped_meta_graph( 2434 filename=os.path.join(test_dir, exported_filename), 2435 graph=ops_lib.get_default_graph(), 2436 export_scope="hidden1") 2437 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2438 2439 with self.test_session(graph=graph) as sess: 2440 sess.run(variables.global_variables_initializer()) 2441 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2442 saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False) 2443 2444 def _testScopedRestore(self, test_dir, exported_filename, 2445 new_exported_filename, ckpt_filename): 2446 graph = ops_lib.Graph() 2447 # Create all the missing inputs. 2448 with graph.as_default(): 2449 new_image = constant_op.constant( 2450 1.2, dtypes.float32, shape=[100, 28], name="images") 2451 var_list = meta_graph.import_scoped_meta_graph( 2452 os.path.join(test_dir, exported_filename), 2453 graph=graph, 2454 input_map={"$unbound_inputs_images": new_image}, 2455 import_scope="new_hidden1") 2456 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2457 hidden1 = graph.as_graph_element("new_hidden1/Relu:0") 2458 weights1 = graph.as_graph_element("new_hidden1/weights:0") 2459 biases1 = graph.as_graph_element("new_hidden1/biases:0") 2460 2461 with graph.as_default(): 2462 # Hidden 2 2463 with ops_lib.name_scope("hidden2"): 2464 weights = variables.Variable( 2465 random_ops.truncated_normal( 2466 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2467 name="weights") 2468 2469 # The use of control_flow_ops.while_loop here is purely for adding test 2470 # coverage the save and restore of control flow context (which doesn't 2471 # make any sense here from a machine learning perspective). The typical 2472 # biases is a simple Variable without the conditions. 2473 def loop_cond(it, _): 2474 return it < 2 2475 2476 def loop_body(it, biases): 2477 biases += constant_op.constant(0.1, shape=[32]) 2478 return it + 1, biases 2479 2480 _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [ 2481 constant_op.constant(0), variables.Variable(array_ops.zeros([32])) 2482 ]) 2483 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 2484 # Linear 2485 with ops_lib.name_scope("softmax_linear"): 2486 weights = variables.Variable( 2487 random_ops.truncated_normal( 2488 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2489 name="weights") 2490 biases = variables.Variable(array_ops.zeros([10]), name="biases") 2491 logits = math_ops.matmul(hidden2, weights) + biases 2492 ops_lib.add_to_collection("logits", logits) 2493 2494 # The rest of the variables. 2495 rest_variables = list( 2496 set(variables.global_variables()) - set(var_list.keys())) 2497 init_rest_op = variables.initialize_variables(rest_variables) 2498 2499 with self.test_session(graph=graph) as sess: 2500 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2501 saver.restore(sess, os.path.join(test_dir, ckpt_filename)) 2502 # Verify that we have restored weights1 and biases1. 2503 sess.run([weights1, biases1]) 2504 # Initialize the rest of the variables and run logits. 2505 sess.run(init_rest_op) 2506 sess.run(logits) 2507 2508 # Verifies that we can save the subgraph under "hidden1" and restore it 2509 # into "new_hidden1" in the new graph. 2510 def testScopedSaveAndRestore(self): 2511 test_dir = self._get_test_dir("scoped_export_import") 2512 ckpt_filename = "ckpt" 2513 self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename) 2514 self._testScopedRestore(test_dir, "exported_hidden1.pbtxt", 2515 "exported_new_hidden1.pbtxt", ckpt_filename) 2516 2517 # Verifies that we can copy the subgraph under "hidden1" and copy it 2518 # to different name scope in the same graph or different graph. 2519 def testCopyScopedGraph(self): 2520 test_dir = self._get_test_dir("scoped_copy") 2521 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2522 graph1 = ops_lib.Graph() 2523 with graph1.as_default(): 2524 with ops_lib.name_scope("hidden1"): 2525 images = constant_op.constant( 2526 1.0, dtypes.float32, shape=[3, 2], name="images") 2527 weights1 = variables.Variable( 2528 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2529 biases1 = variables.Variable([0.1] * 3, name="biases") 2530 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2531 2532 # Run the graph and save scoped checkpoint. 2533 with self.test_session(graph=graph1) as sess: 2534 sess.run(variables.global_variables_initializer()) 2535 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2536 export_scope="hidden1") 2537 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2538 saver.save(sess, saver0_ckpt, write_state=False) 2539 2540 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2541 2542 # Verifies copy to the same graph with the same name fails. 2543 with graph1.as_default(): 2544 with self.assertRaisesWithPredicateMatch( 2545 ValueError, lambda e: "need to be different" in str(e)): 2546 meta_graph.copy_scoped_meta_graph( 2547 from_scope="hidden1", to_scope="hidden1") 2548 2549 # Verifies copy to the same graph. 2550 with graph1.as_default(): 2551 var_list_2 = meta_graph.copy_scoped_meta_graph( 2552 from_scope="hidden1", to_scope="hidden2") 2553 2554 with self.test_session(graph=graph1) as sess: 2555 saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2556 saver1.restore(sess, saver0_ckpt) 2557 saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1) 2558 saver2.restore(sess, saver0_ckpt) 2559 self.assertAllClose(expected, sess.run("hidden1/relu:0")) 2560 self.assertAllClose(expected, sess.run("hidden2/relu:0")) 2561 2562 # Verifies copy to differen graph. 2563 graph2 = ops_lib.Graph() 2564 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2565 from_scope="hidden1", 2566 to_scope="new_hidden1", 2567 from_graph=graph1, 2568 to_graph=graph2) 2569 2570 with self.test_session(graph=graph2) as sess: 2571 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2572 saver3.restore(sess, saver0_ckpt) 2573 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2574 2575 def testExportGraphDefWithScope(self): 2576 test_dir = self._get_test_dir("export_graph_def") 2577 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2578 graph1 = ops_lib.Graph() 2579 with graph1.as_default(): 2580 with ops_lib.name_scope("hidden1"): 2581 images = constant_op.constant( 2582 1.0, dtypes.float32, shape=[3, 2], name="images") 2583 weights1 = variables.Variable( 2584 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2585 biases1 = variables.Variable([0.1] * 3, name="biases") 2586 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2587 2588 # Run the graph and save scoped checkpoint. 2589 with self.test_session(graph=graph1) as sess: 2590 sess.run(variables.global_variables_initializer()) 2591 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2592 graph_def=graph1.as_graph_def(), export_scope="hidden1") 2593 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2594 saver.save(sess, saver0_ckpt, write_state=False) 2595 2596 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2597 2598 # Verifies that we can run successfully after restoring. 2599 graph2 = ops_lib.Graph() 2600 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2601 from_scope="hidden1", 2602 to_scope="new_hidden1", 2603 from_graph=graph1, 2604 to_graph=graph2) 2605 2606 with self.test_session(graph=graph2) as sess: 2607 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2608 saver3.restore(sess, saver0_ckpt) 2609 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2610 2611 def testSerializeSaverWithScope(self): 2612 test_dir = self._get_test_dir("export_graph_def") 2613 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 2614 saver2_ckpt = os.path.join(test_dir, "saver2.ckpt") 2615 graph = ops_lib.Graph() 2616 with graph.as_default(): 2617 with ops_lib.name_scope("hidden1"): 2618 variable1 = variables.Variable([1.0], name="variable1") 2619 saver1 = saver_module.Saver(var_list=[variable1]) 2620 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1) 2621 2622 with ops_lib.name_scope("hidden2"): 2623 variable2 = variables.Variable([2.0], name="variable2") 2624 saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/") 2625 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2) 2626 2627 with self.test_session(graph=graph) as sess: 2628 variables.global_variables_initializer().run() 2629 saver1.save(sess, saver1_ckpt, write_state=False) 2630 saver2.save(sess, saver2_ckpt, write_state=False) 2631 2632 graph1 = ops_lib.Graph() 2633 var_dict1 = meta_graph.copy_scoped_meta_graph( 2634 from_scope="hidden1", 2635 to_scope="new_hidden1", 2636 from_graph=graph, 2637 to_graph=graph1) 2638 self.assertEqual(1, len(var_dict1)) 2639 2640 saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS) 2641 self.assertEqual(1, len(saver_list1)) 2642 2643 with self.test_session(graph=graph1) as sess: 2644 saver_list1[0].restore(sess, saver1_ckpt) 2645 self.assertEqual(1.0, var_dict1["variable1:0"].eval()) 2646 2647 graph2 = ops_lib.Graph() 2648 var_dict2 = meta_graph.copy_scoped_meta_graph( 2649 from_scope="hidden2", 2650 to_scope="new_hidden2", 2651 from_graph=graph, 2652 to_graph=graph2) 2653 self.assertEqual(1, len(var_dict2)) 2654 2655 saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS) 2656 self.assertEqual(1, len(saver_list2)) 2657 2658 with self.test_session(graph=graph2) as sess: 2659 saver_list2[0].restore(sess, saver2_ckpt) 2660 self.assertEqual(2.0, var_dict2["variable2:0"].eval()) 2661 2662 2663if __name__ == "__main__": 2664 test.main() 2665