1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and computes summaries."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21import os
22import time
23
24from tensorflow.core.framework.summary_pb2 import Summary
25from tensorflow.core.util.event_pb2 import SessionLog
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import meta_graph
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import lookup_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.summary import summary as _summary
35from tensorflow.python.training import coordinator
36from tensorflow.python.training import saver as saver_mod
37from tensorflow.python.training import session_manager as session_manager_mod
38from tensorflow.python.training import training_util
39from tensorflow.python.util import deprecation
40from tensorflow.python.util.tf_export import tf_export
41
42
43@tf_export("train.Supervisor")
44class Supervisor(object):
45  """A training helper that checkpoints models and computes summaries.
46
47  This class is deprecated. Please use
48  ${tf.train.MonitoredTrainingSession} instead.
49
50  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
51  and a `SessionManager` that takes care of common needs of TensorFlow
52  training programs.
53
54  #### Use for a single program
55
56  ```python
57  with tf.Graph().as_default():
58    ...add operations to the graph...
59    # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
60    sv = Supervisor(logdir='/tmp/mydir')
61    # Get a TensorFlow session managed by the supervisor.
62    with sv.managed_session(FLAGS.master) as sess:
63      # Use the session to train the graph.
64      while not sv.should_stop():
65        sess.run(<my_train_op>)
66  ```
67
68  Within the `with sv.managed_session()` block all variables in the graph have
69  been initialized.  In addition, a few services have been started to
70  checkpoint the model and add summaries to the event log.
71
  If the program crashes and is restarted, the managed session automatically
  reinitializes variables from the most recent checkpoint.
74
75  The supervisor is notified of any exception raised by one of the services.
76  After an exception is raised, `should_stop()` returns `True`.  In that case
77  the training loop should also stop.  This is why the training loop has to
78  check for `sv.should_stop()`.
79
80  Exceptions that indicate that the training inputs have been exhausted,
81  `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True`
82  but are not re-raised from the `with` block: they indicate a normal
83  termination.
84
85  #### Use for multiple replicas
86
87  To train with replicas you deploy the same program in a `Cluster`.
88  One of the tasks must be identified as the *chief*: the task that handles
89  initialization, checkpoints, summaries, and recovery.  The other tasks
90  depend on the *chief* for these services.
91
92  The only change you have to do to the single program code is to indicate
93  if the program is running as the *chief*.
94
95  ```python
96  # Choose a task as the chief. This could be based on server_def.task_index,
97  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
98  # But there can be only one *chief*.
99  is_chief = (server_def.task_index == 0)
100  server = tf.train.Server(server_def)
101
102  with tf.Graph().as_default():
103    ...add operations to the graph...
104    # Create a Supervisor that uses log directory on a shared file system.
105    # Indicate if you are the 'chief'
106    sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
107    # Get a Session in a TensorFlow server on the cluster.
108    with sv.managed_session(server.target) as sess:
109      # Use the session to train the graph.
110      while not sv.should_stop():
111        sess.run(<my_train_op>)
112  ```
113
114  In the *chief* task, the `Supervisor` works exactly as in the first example
115  above.  In the other tasks `sv.managed_session()` waits for the Model to have
116  been initialized before returning a session to the training code.  The
117  non-chief tasks depend on the chief task for initializing the model.
118
119  If one of the tasks crashes and restarts, `managed_session()`
120  checks if the Model is initialized.  If yes, it just creates a session and
121  returns it to the training code that proceeds normally.  If the model needs
122  to be initialized, the chief task takes care of reinitializing it; the other
123  tasks just wait for the model to have been initialized.
124
125  NOTE: This modified program still works fine as a single program.
126  The single program marks itself as the chief.
127
128  #### What `master` string to use
129
130  Whether you are running on your machine or in the cluster you can use the
131  following values for the --master flag:
132
133  * Specifying `''` requests an in-process session that does not use RPC.
134
135  * Specifying `'local'` requests a session that uses the RPC-based
136    "Master interface" to run TensorFlow programs. See
137    @{tf.train.Server.create_local_server} for
138    details.
139
140  * Specifying `'grpc://hostname:port'` requests a session that uses
141    the RPC interface to a specific host, and also allows the in-process
142    master to access remote tensorflow workers. Often, it is
    appropriate to pass `server.target` (for some `tf.train.Server`
    named `server`).
145
146  #### Advanced use
147
148  ##### Launching additional services
149
150  `managed_session()` launches the Checkpoint and Summary services (threads).
151  If you need more services to run you can simply launch them in the block
152  controlled by `managed_session()`.
153
154  Example: Start a thread to print losses.  We want this thread to run
155  every 60 seconds, so we launch it with `sv.loop()`.
156
157  ```python
158  ...
159  sv = Supervisor(logdir='/tmp/mydir')
160  with sv.managed_session(FLAGS.master) as sess:
161    sv.loop(60, print_loss, (sess, ))
162    while not sv.should_stop():
163      sess.run(my_train_op)
164  ```
165
166  ##### Launching fewer services
167
168  `managed_session()` launches the "summary" and "checkpoint" threads which use
  either the optional `summary_op` and `saver` passed to the constructor, or
170  default ones created automatically by the supervisor.  If you want to run
171  your own summary and checkpointing logic, disable these services by passing
172  `None` to the `summary_op` and `saver` parameters.
173
174  Example: Create summaries manually every 100 steps in the chief.
175
176  ```python
177  # Create a Supervisor with no automatic summaries.
178  sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
179  # As summary_op was None, managed_session() does not start the
180  # summary thread.
181  with sv.managed_session(FLAGS.master) as sess:
182    for step in xrange(1000000):
183      if sv.should_stop():
184        break
185      if is_chief and step % 100 == 0:
186        # Create the summary every 100 chief steps.
187        sv.summary_computed(sess, sess.run(my_summary_op))
188      else:
189        # Train normally
190        sess.run(my_train_op)
191  ```
192
193  ##### Custom model initialization
194
195  `managed_session()` only supports initializing the model by running an
196  `init_op` or restoring from the latest checkpoint.  If you have special
197  initialization needs, see how to specify a `local_init_op` when creating the
198  supervisor.  You can also use the `SessionManager` directly to create a
199  session and check if it could be initialized automatically.
200  """
201
202  # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver',
203  # and 'global_step' parameters of Supervisor.__init__() to indicate that
204  # the default behavior should be used.
205  USE_DEFAULT = 0
206
207  @deprecation.deprecated(None,
208                          "Please switch to tf.train.MonitoredTrainingSession")
209  def __init__(self,
210               graph=None,
211               ready_op=USE_DEFAULT,
212               ready_for_local_init_op=USE_DEFAULT,
213               is_chief=True,
214               init_op=USE_DEFAULT,
215               init_feed_dict=None,
216               local_init_op=USE_DEFAULT,
217               logdir=None,
218               summary_op=USE_DEFAULT,
219               saver=USE_DEFAULT,
220               global_step=USE_DEFAULT,
221               save_summaries_secs=120,
222               save_model_secs=600,
223               recovery_wait_secs=30,
224               stop_grace_secs=120,
225               checkpoint_basename="model.ckpt",
226               session_manager=None,
227               summary_writer=USE_DEFAULT,
228               init_fn=None):
229    """Create a `Supervisor`.
230
231    Args:
232      graph: A `Graph`.  The graph that the model will use.  Defaults to the
233        default `Graph`.  The supervisor may add operations to the graph before
234        creating a session, but the graph should not be modified by the caller
235        after passing it to the supervisor.
236      ready_op: 1-D string `Tensor`.  This tensor is evaluated by supervisors in
237        `prepare_or_wait_for_session()` to check if the model is ready to use.
238        The model is considered ready if it returns an empty array.  Defaults to
239        the tensor returned from `tf.report_uninitialized_variables()`  If
240        `None`, the model is not checked for readiness.
241      ready_for_local_init_op: 1-D string `Tensor`.  This tensor is evaluated by
242        supervisors in `prepare_or_wait_for_session()` to check if the model is
243        ready to run the local_init_op.
244        The model is considered ready if it returns an empty array.  Defaults to
245        the tensor returned from
246        `tf.report_uninitialized_variables(tf.global_variables())`. If `None`,
247        the model is not checked for readiness before running local_init_op.
248      is_chief: If True, create a chief supervisor in charge of initializing
249        and restoring the model.  If False, create a supervisor that relies
250        on a chief supervisor for inits and restore.
251      init_op: `Operation`.  Used by chief supervisors to initialize the model
252        when it can not be recovered.  Defaults to an `Operation` that
253        initializes all global variables.  If `None`, no initialization is done
254        automatically unless you pass a value for `init_fn`, see below.
255      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
256        This feed dictionary will be used when `init_op` is evaluated.
257      local_init_op: `Operation`. Used by all supervisors to run initializations
258        that should run for every new supervisor instance. By default these
259        are table initializers and initializers for local variables.
260        If `None`, no further per supervisor-instance initialization is
261        done automatically.
262      logdir: A string.  Optional path to a directory where to checkpoint the
263        model and log events for the visualizer.  Used by chief supervisors.
264        The directory will be created if it does not exist.
265      summary_op: An `Operation` that returns a Summary for the event logs.
266        Used by chief supervisors if a `logdir` was specified.  Defaults to the
267        operation returned from summary.merge_all().  If `None`, summaries are
268        not computed automatically.
269      saver: A Saver object.  Used by chief supervisors if a `logdir` was
270        specified.  Defaults to the saved returned by Saver().
271        If `None`, the model is not saved automatically.
272      global_step: An integer Tensor of size 1 that counts steps.  The value
273        from 'global_step' is used in summaries and checkpoint filenames.
274        Default to the op named 'global_step' in the graph if it exists, is of
275        rank 1, size 1, and of type tf.int32 or tf.int64.  If `None` the global
276        step is not recorded in summaries and checkpoint files.  Used by chief
277        supervisors if a `logdir` was specified.
278      save_summaries_secs: Number of seconds between the computation of
279        summaries for the event log.  Defaults to 120 seconds.  Pass 0 to
280        disable summaries.
281      save_model_secs: Number of seconds between the creation of model
282        checkpoints.  Defaults to 600 seconds.  Pass 0 to disable checkpoints.
283      recovery_wait_secs: Number of seconds between checks that the model
284        is ready.  Used by supervisors when waiting for a chief supervisor
285        to initialize or restore the model.  Defaults to 30 seconds.
286      stop_grace_secs: Grace period, in seconds, given to running threads to
287        stop when `stop()` is called.  Defaults to 120 seconds.
288      checkpoint_basename: The basename for checkpoint saving.
289      session_manager: `SessionManager`, which manages Session creation and
290        recovery. If it is `None`, a default `SessionManager` will be created
291        with the set of arguments passed in for backwards compatibility.
292      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`.  Can be `None`
293        to indicate that no summaries should be written.
294      init_fn: Optional callable used to initialize the model. Called
295        after the optional `init_op` is called.  The callable must accept one
296        argument, the session being initialized.
297
298    Returns:
299      A `Supervisor`.
300
301    Raises:
302      RuntimeError: If called with eager execution enabled.
303
304    @compatibility(eager)
305    `Supervisor`s are not supported when eager execution is enabled.
306    @end_compatibility
307    """
308    if context.in_eager_mode():
309      raise RuntimeError("Supervisors are compatible with eager execution.")
310    # Set default values of arguments.
311    if graph is None:
312      graph = ops.get_default_graph()
313    with graph.as_default():
314      self._init_ready_op(
315          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
316      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
317      self._init_local_init_op(local_init_op=local_init_op)
318      self._init_saver(saver=saver)
319      self._init_summary_op(summary_op=summary_op)
320      self._init_global_step(global_step=global_step)
321    self._graph = graph
322    self._meta_graph_def = meta_graph.create_meta_graph_def(
323        graph_def=graph.as_graph_def(add_shapes=True),
324        saver_def=self._saver.saver_def if self._saver else None)
325    self._is_chief = is_chief
326    self._coord = coordinator.Coordinator()
327    self._recovery_wait_secs = recovery_wait_secs
328    self._stop_grace_secs = stop_grace_secs
329    self._init_fn = init_fn
330
331    # Set all attributes related to checkpointing and writing events to None.
332    # Afterwards, set them appropriately for chief supervisors, as these are
333    # the only supervisors that can write checkpoints and events.
334    self._logdir = None
335    self._save_summaries_secs = None
336    self._save_model_secs = None
337    self._save_path = None
338    self._summary_writer = None
339
340    if self._is_chief:
341      self._logdir = logdir
342      self._save_summaries_secs = save_summaries_secs
343      self._save_model_secs = save_model_secs
344      if self._logdir:
345        self._save_path = os.path.join(self._logdir, checkpoint_basename)
346      if summary_writer is Supervisor.USE_DEFAULT:
347        if self._logdir:
348          self._summary_writer = _summary.FileWriter(self._logdir)
349      else:
350        self._summary_writer = summary_writer
351      self._graph_added_to_summary = False
352
353    self._init_session_manager(session_manager=session_manager)
354    self._verify_setup()
355    # The graph is not allowed to change anymore.
356    graph.finalize()
357
358  def _init_session_manager(self, session_manager=None):
359    if session_manager is None:
360      self._session_manager = session_manager_mod.SessionManager(
361          local_init_op=self._local_init_op,
362          ready_op=self._ready_op,
363          ready_for_local_init_op=self._ready_for_local_init_op,
364          graph=self._graph,
365          recovery_wait_secs=self._recovery_wait_secs)
366    else:
367      self._session_manager = session_manager
368
369  def _get_first_op_from_collection(self, key):
370    """Returns the first `Operation` from a collection.
371
372    Args:
373      key: A string collection key.
374
375    Returns:
376      The first Op found in a collection, or `None` if the collection is empty.
377    """
378    try:
379      op_list = ops.get_collection(key)
380      if len(op_list) > 1:
381        logging.info("Found %d %s operations. Returning the first one.",
382                     len(op_list), key)
383      if op_list:
384        return op_list[0]
385    except LookupError:
386      pass
387
388    return None
389
390  def _init_ready_op(self,
391                     ready_op=USE_DEFAULT,
392                     ready_for_local_init_op=USE_DEFAULT):
393    """Initializes ready_op.
394
395    Args:
396      ready_op: `Tensor` to check if the model is initialized.
397        If it's set to USE_DEFAULT, creates an op that checks all
398        the variables are initialized.
399      ready_for_local_init_op: `Tensor` to check if the model is ready to run
400        local_init_op.
401        If it's set to USE_DEFAULT, creates an op that checks all
402        the global variables are initialized.
403    """
404    if ready_op is Supervisor.USE_DEFAULT:
405      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
406      if ready_op is None:
407        ready_op = variables.report_uninitialized_variables()
408        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
409    self._ready_op = ready_op
410
411    # ready_for_local_init_op defaults to None for backward compatibility
412    if ready_for_local_init_op is Supervisor.USE_DEFAULT:
413      ready_for_local_init_op = self._get_first_op_from_collection(
414          ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
415    self._ready_for_local_init_op = ready_for_local_init_op
416
417  def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
418    """Initializes init_op.
419
420    Args:
421      init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
422        create an op that initializes all variables and tables.
423      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
424        This feed dictionary will be used when `init_op` is evaluated.
425    """
426    if init_op is Supervisor.USE_DEFAULT:
427      init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
428      if init_op is None:
429        init_op = variables.global_variables_initializer()
430        ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
431    self._init_op = init_op
432    self._init_feed_dict = init_feed_dict
433
434  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
435    """Initializes local_init_op.
436
437    Args:
438      local_init_op: `Operation` run for every new supervisor instance. If set
439      to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
440      collection. If the collection is empty, create an op that initializes
441      all local variables and all tables.
442    """
443    if local_init_op is Supervisor.USE_DEFAULT:
444      local_init_op = self._get_first_op_from_collection(
445          ops.GraphKeys.LOCAL_INIT_OP)
446      if local_init_op is None:
447        op_list = [
448            variables.local_variables_initializer(),
449            lookup_ops.tables_initializer()
450        ]
451        if op_list:
452          local_init_op = control_flow_ops.group(*op_list)
453          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
454    self._local_init_op = local_init_op
455
456  def _init_saver(self, saver=USE_DEFAULT):
457    """Initializes saver.
458
459    Args:
460      saver: A `Saver` object. If set to USE_DEFAULT, create one that
461        saves all the variables.
462    """
463    if saver is Supervisor.USE_DEFAULT:
464      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
465      if saver is None and variables.global_variables():
466        saver = saver_mod.Saver()
467        ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
468    self._saver = saver
469
470  def _init_summary_op(self, summary_op=USE_DEFAULT):
471    """Initializes summary_op.
472
473    Args:
474      summary_op: An Operation that returns a Summary for the event logs.
475        If set to USE_DEFAULT, create an op that merges all the summaries.
476    """
477    if summary_op is Supervisor.USE_DEFAULT:
478      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
479      if summary_op is None:
480        summary_op = _summary.merge_all()
481        if summary_op is not None:
482          ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
483    self._summary_op = summary_op
484
485  def _init_global_step(self, global_step=USE_DEFAULT):
486    """Initializes global_step.
487
488    Args:
489      global_step: An integer Tensor of size 1 that counts steps. If
490        set to USE_DEFAULT, creates global_step tensor.
491    """
492    if global_step is Supervisor.USE_DEFAULT:
493      global_step = self._get_first_op_from_collection(
494          ops.GraphKeys.GLOBAL_STEP)
495      if global_step is None:
496        global_step = self._default_global_step_tensor()
497        if global_step is not None:
498          ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
499    self._global_step = global_step
500
501  @property
502  def is_chief(self):
503    """Return True if this is a chief supervisor.
504
505    Returns:
506      A bool.
507    """
508    return self._is_chief
509
510  @property
511  def session_manager(self):
512    """Return the SessionManager used by the Supervisor.
513
514    Returns:
515      A SessionManager object.
516    """
517    return self._session_manager
518
519  @property
520  def coord(self):
521    """Return the Coordinator used by the Supervisor.
522
523    The Coordinator can be useful if you want to run multiple threads
524    during your training.
525
526    Returns:
527      A Coordinator object.
528    """
529    return self._coord
530
531  @property
532  def init_op(self):
533    """Return the Init Op used by the supervisor.
534
535    Returns:
536      An Op or `None`.
537    """
538    return self._init_op
539
540  @property
541  def init_feed_dict(self):
542    """Return the feed dictionary used when evaluating the `init_op`.
543
544    Returns:
545      A feed dictionary or `None`.
546    """
547    return self._init_feed_dict
548
549  @property
550  def ready_op(self):
551    """Return the Ready Op used by the supervisor.
552
553    Returns:
554      An Op or `None`.
555    """
556    return self._ready_op
557
  @property
  def ready_for_local_init_op(self):
    """Return the op used to check readiness for running `local_init_op`.

    Returns:
      An Op or `None`.
    """
    return self._ready_for_local_init_op
561
562  @property
563  def summary_writer(self):
564    """Return the SummaryWriter used by the chief supervisor.
565
566    Returns:
567      A SummaryWriter.
568    """
569    return self._summary_writer
570
571  @property
572  def summary_op(self):
573    """Return the Summary Tensor used by the chief supervisor.
574
575    Returns:
576      A string Tensor for the summary or `None`.
577    """
578    return self._summary_op
579
580  @property
581  def save_summaries_secs(self):
582    """Return the delay between summary computations.
583
584    Returns:
585      A timestamp.
586    """
587    return self._save_summaries_secs
588
589  @property
590  def global_step(self):
591    """Return the global_step Tensor used by the supervisor.
592
593    Returns:
594      An integer Tensor for the global_step.
595    """
596    return self._global_step
597
598  @property
599  def saver(self):
600    """Return the Saver used by the supervisor.
601
602    Returns:
603      A Saver object.
604    """
605    return self._saver
606
607  @property
608  def save_model_secs(self):
609    """Return the delay between checkpoints.
610
611    Returns:
612      A timestamp.
613    """
614    return self._save_model_secs
615
616  @property
617  def save_path(self):
618    """Return the save path used by the supervisor.
619
620    Returns:
621      A string.
622    """
623    return self._save_path
624
625  def _write_graph(self):
626    """Writes graph_def to `logdir` and adds it to summary if applicable."""
627    assert self._is_chief
628    if self._logdir:
629      training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
630                                self._logdir, "graph.pbtxt")
631    if self._summary_writer and not self._graph_added_to_summary:
632      self._summary_writer.add_graph(self._graph)
633      self._summary_writer.add_meta_graph(self._meta_graph_def)
634      self._graph_added_to_summary = True
635
636  def start_standard_services(self, sess):
637    """Start the standard services for 'sess'.
638
639    This starts services in the background.  The services started depend
640    on the parameters to the constructor and may include:
641
642      - A Summary thread computing summaries every save_summaries_secs.
643      - A Checkpoint thread saving the model every save_model_secs.
644      - A StepCounter thread measure step time.
645
646    Args:
647      sess: A Session.
648
649    Returns:
650      A list of threads that are running the standard services.  You can use
651      the Supervisor's Coordinator to join these threads with:
652        sv.coord.Join(<list of threads>)
653
654    Raises:
655      RuntimeError: If called with a non-chief Supervisor.
656      ValueError: If not `logdir` was passed to the constructor as the
657        services need a log directory.
658    """
659    if not self._is_chief:
660      raise RuntimeError("Only chief supervisor can start standard services. "
661                         "Because only chief supervisors can write events.")
662
663    if not self._logdir:
664      logging.warning("Standard services need a 'logdir' "
665                      "passed to the SessionManager")
666      return
667
668    if self._global_step is not None and self._summary_writer:
669      # Only add the session log if we keep track of global step.
670      # TensorBoard cannot use START message for purging expired events
671      # if there is no step value.
672      current_step = training_util.global_step(sess, self._global_step)
673      self._summary_writer.add_session_log(
674          SessionLog(status=SessionLog.START),
675          current_step)
676
677    threads = []
678    if self._save_summaries_secs and self._summary_writer:
679      if self._summary_op is not None:
680        threads.append(SVSummaryThread(self, sess))
681      if self._global_step is not None:
682        threads.append(SVStepCounterThread(self, sess))
683    if self.saver and self._save_model_secs:
684      threads.append(SVTimerCheckpointThread(self, sess))
685    for t in threads:
686      t.start()
687    return threads
688
689  def prepare_or_wait_for_session(self, master="", config=None,
690                                  wait_for_checkpoint=False,
691                                  max_wait_secs=7200,
692                                  start_standard_services=True):
693    """Make sure the model is ready to be used.
694
695    Create a session on 'master', recovering or initializing the model as
696    needed, or wait for a session to be ready.  If running as the chief
697    and `start_standard_service` is set to True, also call the session
698    manager to start the standard services.
699
700    Args:
701      master: name of the TensorFlow master to use.  See the `tf.Session`
702        constructor for how this is interpreted.
703      config: Optional ConfigProto proto used to configure the session,
704        which is passed as-is to create the session.
705      wait_for_checkpoint: Whether we should wait for the availability of a
706        checkpoint before creating Session. Defaults to False.
707      max_wait_secs: Maximum time to wait for the session to become available.
708      start_standard_services: Whether to start the standard services and the
709        queue runners.
710
711    Returns:
712      A Session object that can be used to drive the model.
713    """
714    # For users who recreate the session with prepare_or_wait_for_session(), we
715    # need to clear the coordinator's stop_event so that threads managed by the
716    # coordinator can run.
717    self._coord.clear_stop()
718    if self._summary_writer:
719      self._summary_writer.reopen()
720
721    if self._is_chief:
722      sess = self._session_manager.prepare_session(
723          master, init_op=self.init_op, saver=self.saver,
724          checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
725          max_wait_secs=max_wait_secs, config=config,
726          init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
727      self._write_graph()
728      if start_standard_services:
729        logging.info("Starting standard services.")
730        self.start_standard_services(sess)
731    else:
732      sess = self._session_manager.wait_for_session(master,
733                                                    config=config,
734                                                    max_wait_secs=max_wait_secs)
735    if start_standard_services:
736      logging.info("Starting queue runners.")
737      self.start_queue_runners(sess)
738    return sess
739
740  def start_queue_runners(self, sess, queue_runners=None):
741    """Start threads for `QueueRunners`.
742
743    Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
744    are already started automatically when you create a session with the
745    supervisor, so unless you have non-collected queue runners to start
746    you do not need to call this explicitly.
747
748    Args:
749      sess: A `Session`.
750      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
751        list of queue runners gathered in the graph under the key
752        `GraphKeys.QUEUE_RUNNERS`.
753
754    Returns:
755      The list of threads started for the `QueueRunners`.
756
757    Raises:
758      RuntimeError: If called with eager execution enabled.
759
760    @compatibility(eager)
761    Queues are not compatible with eager execution. To ingest data when eager
762    execution is enabled, use the `tf.data` API.
763    @end_compatibility
764    """
765    if context.in_eager_mode():
766      raise RuntimeError("Queues are not compatible with eager execution.")
767    if queue_runners is None:
768      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
769    threads = []
770    for qr in queue_runners:
771      threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
772                                       start=True))
773    return threads
774
775  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
776    """Start a LooperThread that calls a function periodically.
777
778    If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
779    repeatedly.  Otherwise it calls it every `timer_interval_secs`
780    seconds.  The thread terminates when a stop is requested.
781
782    The started thread is added to the list of threads managed by the supervisor
783    so it does not need to be passed to the `stop()` method.
784
785    Args:
786      timer_interval_secs: Number. Time boundaries at which to call `target`.
787      target: A callable object.
788      args: Optional arguments to pass to `target` when calling it.
789      kwargs: Optional keyword arguments to pass to `target` when calling it.
790
791    Returns:
792      The started thread.
793    """
794    looper = coordinator.LooperThread(self._coord, timer_interval_secs,
795                                      target=target, args=args, kwargs=kwargs)
796    looper.start()
797    return looper
798
799  def stop(self,
800           threads=None,
801           close_summary_writer=True,
802           ignore_live_threads=False):
803    """Stop the services and the coordinator.
804
805    This does not close the session.
806
807    Args:
808      threads: Optional list of threads to join with the coordinator.  If
809        `None`, defaults to the threads running the standard services, the
810        threads started for `QueueRunners`, and the threads started by the
811        `loop()` method.  To wait on additional threads, pass the
812        list in this parameter.
813      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
814        `True` if the summary writer was created by the supervisor, `False`
815        otherwise.
816      ignore_live_threads: If `True` ignores threads that remain running after
817        a grace period when joining threads via the coordinator, instead of
818        raising a RuntimeError.
819    """
820    self._coord.request_stop()
821    try:
822      # coord.join() re-raises the first reported exception; the "finally"
823      # block ensures that we clean up whether or not an exception was
824      # reported.
825      self._coord.join(
826          threads,
827          stop_grace_period_secs=self._stop_grace_secs,
828          ignore_live_threads=ignore_live_threads)
829    finally:
830      # Close the writer last, in case one of the running threads was using it.
831      if close_summary_writer and self._summary_writer:
832        # Stop messages are not logged with event.step,
833        # since the session may have already terminated.
834        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
835        self._summary_writer.close()
836        self._graph_added_to_summary = False
837
838  def request_stop(self, ex=None):
839    """Request that the coordinator stop the threads.
840
841    See `Coordinator.request_stop()`.
842
843    Args:
844      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
845        `sys.exc_info()`.  If this is the first call to `request_stop()` the
846        corresponding exception is recorded and re-raised from `join()`.
847    """
848    self._coord.request_stop(ex=ex)
849
850  def should_stop(self):
851    """Check if the coordinator was told to stop.
852
853    See `Coordinator.should_stop()`.
854
855    Returns:
856      True if the coordinator was told to stop, False otherwise.
857    """
858    return self._coord.should_stop()
859
860  def stop_on_exception(self):
861    """Context handler to stop the supervisor when an exception is raised.
862
863    See `Coordinator.stop_on_exception()`.
864
865    Returns:
866      A context handler.
867    """
868    return self._coord.stop_on_exception()
869
870  def wait_for_stop(self):
871    """Block waiting for the coordinator to stop."""
872    self._coord.wait_for_stop()
873
874  def summary_computed(self, sess, summary, global_step=None):
875    """Indicate that a summary was computed.
876
877    Args:
878      sess: A `Session` object.
879      summary: A Summary proto, or a string holding a serialized summary proto.
880      global_step: Int. global step this summary is associated with. If `None`,
881        it will try to fetch the current step.
882
883    Raises:
884      TypeError: if 'summary' is not a Summary proto or a string.
885      RuntimeError: if the Supervisor was created without a `logdir`.
886    """
887    if not self._summary_writer:
888      raise RuntimeError("Writing a summary requires a summary writer.")
889    if global_step is None and self.global_step is not None:
890      global_step = training_util.global_step(sess, self.global_step)
891    self._summary_writer.add_summary(summary, global_step)
892
893  def _default_global_step_tensor(self):
894    """Returns the global_step from the default graph.
895
896    Returns:
897      The global step `Tensor` or `None`.
898    """
899    try:
900      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
901      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
902        return gs
903      else:
904        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
905        return None
906    except KeyError:
907      return None
908
909  def _verify_setup(self):
910    """Check that all is good.
911
912    Raises:
913      ValueError: If something is not good.
914    """
915    # Not running as chief means that replicas are used.
916    # In that case all Variables must have their device set.
917    if not self._is_chief:
918      for op in self._graph.get_operations():
919        if op.type in ["Variable", "VariableV2"] and not op.device:
920          raise ValueError("When using replicas, all Variables must have "
921                           "their device set: %s" % op)
922
923  # pylint: disable=g-doc-return-or-yield,broad-except
924  @contextlib.contextmanager
925  def managed_session(self, master="", config=None,
926                      start_standard_services=True,
927                      close_summary_writer=True):
928    """Returns a context manager for a managed session.
929
930    This context manager creates and automatically recovers a session.  It
931    optionally starts the standard services that handle checkpoints and
932    summaries.  It monitors exceptions raised from the `with` block or from the
933    services and stops the supervisor as needed.
934
935    The context manager is typically used as follows:
936
937    ```python
938    def train():
939      sv = tf.train.Supervisor(...)
940      with sv.managed_session(<master>) as sess:
941        for step in xrange(..):
942          if sv.should_stop():
943            break
944          sess.run(<my training op>)
945          ...do other things needed at each training step...
946    ```
947
948    An exception raised from the `with` block or one of the service threads is
949    raised again when the block exits.  This is done after stopping all threads
950    and closing the session.  For example, an `AbortedError` exception, raised
951    in case of preemption of one of the workers in a distributed model, is
952    raised again when the block exits.
953
954    If you want to retry the training loop in case of preemption you can do it
955    as follows:
956
957    ```python
958    def main(...):
959      while True
960        try:
961          train()
962        except tf.errors.Aborted:
963          pass
964    ```
965
966    As a special case, exceptions used for control flow, such as
967    `OutOfRangeError` which reports that input queues are exhausted, are not
968    raised again from the `with` block: they indicate a clean termination of
969    the training loop and are considered normal termination.
970
971    Args:
972      master: name of the TensorFlow master to use.  See the `tf.Session`
973        constructor for how this is interpreted.
974      config: Optional `ConfigProto` proto used to configure the session.
975        Passed as-is to create the session.
976      start_standard_services: Whether to start the standard services,
977        such as checkpoint, summary and step counter.
978      close_summary_writer: Whether to close the summary writer when
979        closing the session.  Defaults to True.
980
981    Returns:
982      A context manager that yields a `Session` restored from the latest
983      checkpoint or initialized from scratch if not checkpoint exists.  The
984      session is closed when the `with` block exits.
985    """
986    try:
987      sess = self.prepare_or_wait_for_session(
988          master=master, config=config,
989          start_standard_services=start_standard_services)
990      yield sess
991    except Exception as e:
992      self.request_stop(e)
993    finally:
994      try:
995        # Request all the threads to stop and wait for them to do so.  Any
996        # exception raised by the threads is raised again from stop().
997        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
998        # threads which are not checking for `should_stop()`.  They
999        # will be stopped when we close the session further down.
1000        self.stop(close_summary_writer=close_summary_writer)
1001      finally:
1002        # Close the session to finish up all pending calls.  We do not care
1003        # about exceptions raised when closing.  This takes care of
1004        # blocked enqueue/dequeue calls.
1005        try:
1006          sess.close()
1007        except Exception:
1008          # Silently ignore exceptions raised by close().
1009          pass
1010  # pylint: enable=g-doc-return-or-yield,broad-except
1011
1012
class SVSummaryThread(coordinator.LooperThread):
  """A looper thread that periodically writes summaries."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    sv = self._sv
    if sv.global_step is None:
      # No global step tensor: record the summary without a step number.
      global_step = None
      summary_strs = self._sess.run(sv.summary_op)
    else:
      # Fetch the summary and the current step in a single run call.
      summary_strs, global_step = self._sess.run(
          [sv.summary_op, sv.global_step])
    if sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      sv.summary_writer.add_summary(summary_strs, global_step)
1037
1038
class SVStepCounterThread(coordinator.LooperThread):
  """A looper thread that counts steps and measures their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By defaults, it uses
        sv.global_step.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    if step_counter is None:
      step_counter = sv.global_step
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    # Take an initial reading so the first run_loop() measures a real delta.
    self._last_time = time.time()
    self._last_step = training_util.global_step(
        self._sess, self._step_counter)

  def run_loop(self):
    # How many steps were completed since the previous invocation.
    current_step = training_util.global_step(self._sess, self._step_counter)
    steps_done = current_step - self._last_step
    self._last_step = current_step
    # How much wall time has passed since the previous invocation.
    now = time.time()
    elapsed = now - self._last_time
    self._last_time = now
    # Steps-per-second rate; guard against a zero-length interval.
    steps_per_sec = steps_done / elapsed if elapsed > 0. else float("inf")
    summary = Summary(value=[Summary.Value(tag=self._summary_tag,
                                           simple_value=steps_per_sec)])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10,
                        self._summary_tag, steps_per_sec)
1085
1086
class SVTimerCheckpointThread(coordinator.LooperThread):
  """A looper thread that periodically saves checkpoints."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    sv = self._sv
    logging.info("Saving checkpoint to path %s", sv.save_path)
    sv.saver.save(self._sess, sv.save_path, global_step=sv.global_step)
    writer = sv.summary_writer
    if writer and sv.global_step is not None:
      # Record a CHECKPOINT event at the step that was just saved so
      # TensorBoard can associate the checkpoint with its step.
      current_step = training_util.global_step(self._sess, sv.global_step)
      writer.add_session_log(
          SessionLog(status=SessionLog.CHECKPOINT,
                     checkpoint_path=sv.save_path),
          current_step)
1111
1112
# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
# Legacy CamelCase aliases kept for backwards compatibility with callers
# that predate the PEP8 rename.
Supervisor.PrepareSession = Supervisor.prepare_or_wait_for_session
Supervisor.StartQueueRunners = Supervisor.start_queue_runners
Supervisor.StartStandardServices = Supervisor.start_standard_services
Supervisor.Stop = Supervisor.stop
Supervisor.RequestStop = Supervisor.request_stop
Supervisor.Loop = Supervisor.loop
Supervisor.ShouldStop = Supervisor.should_stop
Supervisor.StopOnException = Supervisor.stop_on_exception
Supervisor.WaitForStop = Supervisor.wait_for_stop
Supervisor.SummaryComputed = Supervisor.summary_computed
1124