1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Base Estimator class."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import copy
23import os
24import tempfile
25
26import numpy as np
27import six
28
29from google.protobuf import message
30from tensorflow.core.framework import summary_pb2
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.python.client import session as tf_session
33from tensorflow.python.data.ops import dataset_ops
34from tensorflow.python.eager import context
35from tensorflow.python.estimator import model_fn as model_fn_lib
36from tensorflow.python.estimator import run_config
37from tensorflow.python.estimator import util
38from tensorflow.python.estimator import warm_starting_util
39from tensorflow.python.estimator.export.export import build_all_signature_defs
40from tensorflow.python.estimator.export.export import get_temp_export_dir
41from tensorflow.python.estimator.export.export import get_timestamped_export_dir
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import random_seed
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import metrics as metrics_lib
46from tensorflow.python.platform import gfile
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.saved_model import builder as saved_model_builder
49from tensorflow.python.saved_model import tag_constants
50from tensorflow.python.summary import summary
51from tensorflow.python.summary.writer import writer_cache
52from tensorflow.python.training import evaluation
53from tensorflow.python.training import monitored_session
54from tensorflow.python.training import saver
55from tensorflow.python.training import training
56from tensorflow.python.training import training_util
57from tensorflow.python.util import compat
58from tensorflow.python.util import compat_internal
59from tensorflow.python.util import nest
60from tensorflow.python.util.tf_export import tf_export
61
62
63_VALID_MODEL_FN_ARGS = set(
64    ['features', 'labels', 'mode', 'params', 'self', 'config'])
65
66
67@tf_export('estimator.Estimator')
68class Estimator(object):
69  """Estimator class to train and evaluate TensorFlow models.
70
71  The `Estimator` object wraps a model which is specified by a `model_fn`,
72  which, given inputs and a number of other parameters, returns the ops
73  necessary to perform training, evaluation, or predictions.
74
75  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
76  subdirectory thereof. If `model_dir` is not set, a temporary directory is
77  used.
78
  The `config` argument can be passed a `RunConfig` object containing
  information about the execution environment. It is passed on to the
  `model_fn` if the `model_fn` has a parameter named "config" (and to the input
  functions in the same manner). If the `config` parameter is not passed, a
  default `RunConfig` is instantiated by the `Estimator`, which means that
  defaults useful for local execution are used. `Estimator` makes config
  available to the model (for instance, to allow specialization based on the
  number of workers available), and also uses some of its fields to control
  internals, especially regarding checkpointing.
87
88  The `params` argument contains hyperparameters. It is passed to the
89  `model_fn`, if the `model_fn` has a parameter named "params", and to the input
90  functions in the same manner. `Estimator` only passes params along, it does
91  not inspect it. The structure of `params` is therefore entirely up to the
92  developer.
93
  Apart from a small allow-list of private hooks (such as `_call_input_fn` and
  `_create_global_step`), none of `Estimator`'s methods can be overridden in
  subclasses (its constructor enforces this). Subclasses should use `model_fn`
  to configure the base class, and may add methods implementing specialized
  functionality.
97
98  @compatibility(eager)
99  Estimators are not compatible with eager execution.
100  @end_compatibility
101  """
102
  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See @{$estimators} for more information. To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    @{tf.estimator.WarmStartSettings$WarmStartSettings}.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same (for multi-head models). If
                 mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
                 the `model_fn`'s signature does not accept `mode`, the
                 `model_fn` must still be able to handle `labels=None`.
          * `mode`: Optional. Specifies if this training, evaluation or
                 prediction. See `ModeKeys`.
          * `params`: Optional `dict` of hyperparameters.  Will receive what
                 is passed to Estimator in `params` parameter. This allows
                 to configure Estimators from hyper parameter tuning.
          * `config`: Optional configuration object. Will receive what is passed
                 to Estimator in `config` parameter, or the default `config`.
                 Allows updating things in your model_fn based on configuration
                 such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `EstimatorSpec`

      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be same. If both are `None`, a
        temporary directory will be used.
      config: A `RunConfig` object. If `None`, a default `RunConfig()` is
        created. If its `model_dir` is unset, it is replaced by a copy whose
        `model_dir` is the directory actually used by this estimator.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.
              A deep copy is stored, so later changes to the caller's dict are
              not seen by the estimator.
      warm_start_from: Optional string filepath to a checkpoint to warm-start
                       from, or a `tf.estimator.WarmStartSettings` object to
                       fully configure warm-starting.  If the string filepath is
                       provided instead of a `WarmStartSettings`, then all
                       variables are warm-started, and it is assumed that
                       vocabularies and Tensor names are unchanged.

    Raises:
      RuntimeError: If eager execution is enabled.
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
      ValueError: If `config` is not `None` and not a `RunConfig` instance.
      ValueError: If `model_dir` and `config.model_dir` are both set but
        differ.
      ValueError: If `model_fn` is `None`.
    """
    # Estimators build and run graphs, which is incompatible with eager mode.
    if context.in_eager_mode():
      raise RuntimeError(
          'Estimators are not supported when eager execution is enabled.')

    # Reject subclasses that redefine non-overridable Estimator members.
    Estimator._assert_members_are_not_overridden(self)

    if config is None:
      self._config = run_config.RunConfig()
      logging.info('Using default config.')
    else:
      if not isinstance(config, run_config.RunConfig):
        raise ValueError(
            'config must be an instance of RunConfig, but provided %s.' %
            config)
      self._config = config

    # Model directory.
    # Normalize PathLike objects to a string so they compare with the config's.
    model_dir = compat_internal.path_to_str(model_dir)
    if (model_dir is not None) and (self._config.model_dir is not None):
      if model_dir != self._config.model_dir:
        # TODO(alanyee): remove this suppression after it is no longer needed
        # pylint: disable=g-doc-exception
        raise ValueError(
            "model_dir are set both in constructor and RunConfig, but with "
            "different values. In constructor: '{}', in RunConfig: "
            "'{}' ".format(model_dir, self._config.model_dir))
        # pylint: enable=g-doc-exception

    # Constructor argument first, then the config's value, then a new tempdir.
    self._model_dir = model_dir or self._config.model_dir
    if self._model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    # Keep the stored config's model_dir in sync with the directory in use.
    if self._config.model_dir is None:
      self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    # Without an explicit session config, allow soft device placement.
    if self._config.session_config is None:
      self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      self._session_config = self._config.session_config

    self._device_fn = _get_replica_device_setter(self._config)

    if model_fn is None:
      raise ValueError('model_fn must be provided to Estimator.')
    _verify_model_fn_args(model_fn, params)
    self._model_fn = model_fn
    # Deep-copy so later mutation of the caller's dict cannot leak in.
    self._params = copy.deepcopy(params or {})

    # Normalize `warm_start_from` (a path or WarmStartSettings) into settings.
    # pylint: disable=protected-access
    self._warm_start_settings = (
        warm_starting_util._get_default_warm_start_settings(warm_start_from))
    # pylint: enable=protected-access
222
223  @property
224  def model_dir(self):
225    return self._model_dir
226
227  @property
228  def config(self):
229    return copy.deepcopy(self._config)
230
231  @property
232  def params(self):
233    return copy.deepcopy(self._params)
234
235  @property
236  def model_fn(self):
237    """Returns the model_fn which is bound to self.params.
238
239    Returns:
240      The model_fn with following signature:
241        `def model_fn(features, labels, mode, config)`
242    """
243
244    def public_model_fn(features, labels, mode, config):
245      return self._call_model_fn(features, labels, mode, config)
246
247    return public_model_fn
248
249  # TODO(ispir): support a list of names
250  def get_variable_value(self, name):
251    """Returns value of the variable given by name.
252
253    Args:
254      name: string or a list of string, name of the tensor.
255
256    Returns:
257      Numpy array - value of the tensor.
258
259    Raises:
260      ValueError: If the Estimator has not produced a checkpoint yet.
261    """
262    _check_checkpoint_available(self.model_dir)
263    return training.load_variable(self.model_dir, name)
264
265  def get_variable_names(self):
266    """Returns list of all variable names in this model.
267
268    Returns:
269      List of names.
270
271    Raises:
272      ValueError: If the Estimator has not produced a checkpoint yet.
273    """
274    _check_checkpoint_available(self.model_dir)
275    return [name for name, _ in training.list_variables(self.model_dir)]
276
277  def latest_checkpoint(self):
278    """Finds the filename of latest saved checkpoint file in `model_dir`.
279
280    Returns:
281      The full path to the latest checkpoint or `None` if no checkpoint was
282      found.
283    """
284    return saver.latest_checkpoint(self.model_dir)
285
286  def train(self,
287            input_fn,
288            hooks=None,
289            steps=None,
290            max_steps=None,
291            saving_listeners=None):
292    """Trains a model given training data input_fn.
293
294    Args:
295      input_fn: A function that provides input data for training as minibatches.
296        See @{$get_started/premade_estimators#create_input_functions} for more
297        information. The function should construct and return one of
298        the following:
299
300          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
301            tuple (features, labels) with same constraints as below.
302          * A tuple (features, labels): Where features is a `Tensor` or a
303            dictionary of string feature name to `Tensor` and labels is a
304            `Tensor` or a dictionary of string label name to `Tensor`. Both
305            features and labels are consumed by `model_fn`. They should satisfy
306            the expectation of `model_fn` from inputs.
307
308      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
309        inside the training loop.
310      steps: Number of steps for which to train model. If `None`, train forever
311        or train until input_fn generates the `OutOfRange` error or
312        `StopIteration` exception. 'steps' works incrementally. If you call two
313        times train(steps=10) then training occurs in total 20 steps. If
314        `OutOfRange` or `StopIteration` occurs in the middle, training stops
315        before 20 steps. If you don't want to have incremental behavior please
316        set `max_steps` instead. If set, `max_steps` must be `None`.
317      max_steps: Number of total steps for which to train model. If `None`,
318        train forever or train until input_fn generates the `OutOfRange` error
319        or `StopIteration` exception. If set, `steps` must be `None`. If
320        `OutOfRange` or `StopIteration` occurs in the middle, training stops
321        before `max_steps` steps.
322        Two calls to `train(steps=100)` means 200 training
323        iterations. On the other hand, two calls to `train(max_steps=100)` means
324        that the second call will not do any iteration since first call did
325        all 100 steps.
326      saving_listeners: list of `CheckpointSaverListener` objects. Used for
327        callbacks that run immediately before or after checkpoint savings.
328
329    Returns:
330      `self`, for chaining.
331
332    Raises:
333      ValueError: If both `steps` and `max_steps` are not `None`.
334      ValueError: If either `steps` or `max_steps` is <= 0.
335    """
336    if (steps is not None) and (max_steps is not None):
337      raise ValueError('Can not provide both steps and max_steps.')
338    if steps is not None and steps <= 0:
339      raise ValueError('Must specify steps > 0, given: {}'.format(steps))
340    if max_steps is not None and max_steps <= 0:
341      raise ValueError(
342          'Must specify max_steps > 0, given: {}'.format(max_steps))
343
344    if max_steps is not None:
345      start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
346      if max_steps <= start_step:
347        logging.info('Skipping training since max_steps has already saved.')
348        return self
349
350    hooks = _check_hooks_type(hooks)
351    hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
352
353    saving_listeners = _check_listeners_type(saving_listeners)
354    loss = self._train_model(input_fn, hooks, saving_listeners)
355    logging.info('Loss for final step: %s.', loss)
356    return self
357
358  def _convert_train_steps_to_hooks(self, steps, max_steps):
359    if steps is not None or max_steps is not None:
360      return [training.StopAtStepHook(steps, max_steps)]
361    else:
362      return []
363
364  def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
365               name=None):
366    """Evaluates the model given evaluation data input_fn.
367
368    For each step, calls `input_fn`, which returns one batch of data.
369    Evaluates until:
370    - `steps` batches are processed, or
371    - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
372    `StopIteration`).
373
374    Args:
375      input_fn: A function that constructs the input data for evaluation.
376        See @{$get_started/premade_estimators#create_input_functions} for more
377        information. The function should construct and return one of
378        the following:
379
380          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
381            tuple (features, labels) with same constraints as below.
382          * A tuple (features, labels): Where features is a `Tensor` or a
383            dictionary of string feature name to `Tensor` and labels is a
384            `Tensor` or a dictionary of string label name to `Tensor`. Both
385            features and labels are consumed by `model_fn`. They should satisfy
386            the expectation of `model_fn` from inputs.
387
388      steps: Number of steps for which to evaluate model. If `None`, evaluates
389        until `input_fn` raises an end-of-input exception.
390      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
391        inside the evaluation call.
392      checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
393        latest checkpoint in `model_dir` is used.
394      name: Name of the evaluation if user needs to run multiple evaluations on
395        different data sets, such as on training data vs test data. Metrics for
396        different evaluations are saved in separate folders, and appear
397        separately in tensorboard.
398
399    Returns:
400      A dict containing the evaluation metrics specified in `model_fn` keyed by
401      name, as well as an entry `global_step` which contains the value of the
402      global step for which this evaluation was performed.
403
404    Raises:
405      ValueError: If `steps <= 0`.
406      ValueError: If no model has been trained, namely `model_dir`, or the
407        given `checkpoint_path` is empty.
408    """
409    hooks = _check_hooks_type(hooks)
410    hooks.extend(self._convert_eval_steps_to_hooks(steps))
411
412    return self._evaluate_model(
413        input_fn=input_fn,
414        hooks=hooks,
415        checkpoint_path=checkpoint_path,
416        name=name)
417
418  def _convert_eval_steps_to_hooks(self, steps):
419    if steps is None:
420      return []
421
422    if steps <= 0:
423      raise ValueError('Must specify steps > 0, given: {}'.format(steps))
424    return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access
425
  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    """Yields predictions for given features.

    Note: this method is a generator, so none of its body (including the
    checkpoint lookup and the `ValueError`s below) runs until the caller starts
    iterating over the result.

    Args:
      input_fn: A function that constructs the features. Prediction continues
        until `input_fn` raises an end-of-input exception (`OutOfRangeError` or
        `StopIteration`).
        See @{$get_started/premade_estimators#create_input_functions} for more
        information. The function should construct and return one of
        the following:

          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must have
            same constraints as below.
          * features: A `Tensor` or a dictionary of string feature name to
            `Tensor`. features are consumed by `model_fn`. They should satisfy
            the expectation of `model_fn` from inputs.
          * A tuple, in which case the first item is extracted as features.

      predict_keys: list of `str`, name of the keys to predict. It is used if
        the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
        then rest of the predictions will be filtered from the dictionary. If
        `None`, returns all.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the prediction call.
      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
        latest checkpoint in `model_dir` is used.
      yield_single_examples: If False, yield the whole batch as returned by the
        model_fn instead of decomposing the batch into individual elements. This
        is useful if model_fn return some tensor with first dimension not
        equal to the batch size

    Yields:
      Evaluated values of `predictions` tensors. With `yield_single_examples`,
      a dict of predictions yields one dict per example (the i-th entry of each
      value); a non-dict prediction yields each element along its first
      dimension.

    Raises:
      ValueError: Could not find a trained model in model_dir.
      ValueError: if batch length of predictions are not same and
        yield_single_examples is True.
      ValueError: If there is a conflict between `predict_keys` and
        `predictions`. For example if `predict_keys` is not `None` but
        `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = _check_hooks_type(hooks)
    # Check that model has been trained.
    if not checkpoint_path:
      checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise ValueError('Could not find trained model in model_dir: {}.'.format(
          self._model_dir))

    # Build the prediction graph in a fresh graph rather than the default one.
    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      self._create_and_assert_global_step(g)
      features, input_hooks = self._get_features_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.PREDICT)
      # Labels are never available at prediction time.
      estimator_spec = self._call_model_fn(
          features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
      predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
      # Hook order: input hooks, then caller hooks, then model_fn hooks.
      all_hooks = list(input_hooks)
      all_hooks.extend(hooks)
      all_hooks.extend(list(estimator_spec.prediction_hooks or []))
      # The session restores variables from `checkpoint_path` on creation.
      with training.MonitoredSession(
          session_creator=training.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path,
              master=self._config.master,
              scaffold=estimator_spec.scaffold,
              config=self._session_config),
          hooks=all_hooks) as mon_sess:
        while not mon_sess.should_stop():
          preds_evaluated = mon_sess.run(predictions)
          if not yield_single_examples:
            yield preds_evaluated
          elif not isinstance(predictions, dict):
            # Split a single batched value along its first dimension.
            for pred in preds_evaluated:
              yield pred
          else:
            # Split every entry of the dict by example index; all entries must
            # share one batch length (checked by _extract_batch_length).
            for i in range(self._extract_batch_length(preds_evaluated)):
              yield {
                  key: value[i]
                  for key, value in six.iteritems(preds_evaluated)
              }
512
513  def _assert_members_are_not_overridden(self):
514    """Asserts members of `Estimator` are not overridden."""
515    allowed_overrides = set([
516        '_call_input_fn', '_create_global_step',
517        '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
518        '_tf_api_names'
519    ])
520    estimator_members = set([m for m in Estimator.__dict__.keys()
521                             if not m.startswith('__')])
522    subclass_members = set(self.__class__.__dict__.keys())
523    common_members = estimator_members & subclass_members - allowed_overrides
524    overridden_members = [
525        m for m in common_members
526        if Estimator.__dict__[m] != self.__class__.__dict__[m]]
527    if overridden_members:
528      raise ValueError(
529          'Subclasses of Estimator cannot override members of Estimator. '
530          '{} does override {}'.format(self.__class__, overridden_members))
531
  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Exports inference graph as a SavedModel into given dir.

    For a detailed guide, see
    @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys.  One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    The SavedModel and extra assets are first written to a temporary directory
    (derived from the final path by `get_temp_export_dir`), which is renamed to
    the timestamped export directory only at the very end. If an error occurs
    part-way through, this method does not remove the temporary directory.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel, or `None` if no extra assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no export_outputs
          are provided, or no checkpoint can be found.
    """
    # pylint: enable=line-too-long
    if serving_input_receiver_fn is None:
      raise ValueError('serving_input_receiver_fn must be defined.')

    # Build the serving graph in a fresh graph, not the default one.
    with ops.Graph().as_default() as g:
      self._create_and_assert_global_step(g)
      random_seed.set_random_seed(self._config.tf_random_seed)
      serving_input_receiver = serving_input_receiver_fn()

      # Call the model_fn and collect the export_outputs.
      estimator_spec = self._call_model_fn(
          features=serving_input_receiver.features,
          labels=None,
          mode=model_fn_lib.ModeKeys.PREDICT,
          config=self.config)

      # Build the SignatureDefs from receivers and all outputs
      signature_def_map = build_all_signature_defs(
          serving_input_receiver.receiver_tensors,
          estimator_spec.export_outputs,
          serving_input_receiver.receiver_tensors_alternatives)

      # The checkpoint is resolved only after the graph is built, so a missing
      # checkpoint is reported after serving_input_receiver_fn and model_fn
      # have already run.
      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s." % self._model_dir)

      # Everything is written under temp_export_dir and renamed at the end.
      export_dir = get_timestamped_export_dir(export_dir_base)
      temp_export_dir = get_temp_export_dir(export_dir)

      # TODO(soergel): Consider whether MonitoredSession makes sense here
      with tf_session.Session(config=self._session_config) as session:

        # Prefer the saver from the model_fn's scaffold; otherwise fall back to
        # a new sharded Saver.
        saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        # pylint: disable=protected-access
        local_init_op = (
            estimator_spec.scaffold.local_init_op or
            monitored_session.Scaffold._default_local_init_op())
        # pylint: enable=protected-access

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=local_init_op,
            strip_default_attrs=strip_default_attrs)
        builder.save(as_text)

      # Add the extra assets
      if assets_extra:
        # Every path component is converted to bytes before joining, so
        # os.path.join never mixes text and bytes.
        assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

      # Publish the finished export under its final timestamped name.
      gfile.Rename(temp_export_dir, export_dir)
      return export_dir
658
659  def _get_features_from_input_fn(self, input_fn, mode):
660    """Extracts the `features` from return values of `input_fn`."""
661    result = self._call_input_fn(input_fn, mode)
662    input_hooks = []
663    if isinstance(result, dataset_ops.Dataset):
664      iterator = result.make_initializable_iterator()
665      input_hooks.append(_DatasetInitializerHook(iterator))
666      result = iterator.get_next()
667    if isinstance(result, (list, tuple)):
668      # Unconditionally drop the label (the second element of result).
669      result = result[0]
670
671    if not _has_dataset_or_queue_runner(result):
672      logging.warning('Input graph does not use tf.data.Dataset or contain a '
673                      'QueueRunner. That means predict yields forever. '
674                      'This is probably a mistake.')
675    return result, input_hooks
676
677  def _get_features_and_labels_from_input_fn(self, input_fn, mode):
678    """Extracts the `features` and labels from return values of `input_fn`."""
679    result = self._call_input_fn(input_fn, mode)
680    input_hooks = []
681    if isinstance(result, dataset_ops.Dataset):
682      iterator = result.make_initializable_iterator()
683      input_hooks.append(_DatasetInitializerHook(iterator))
684      result = iterator.get_next()
685    if isinstance(result, (list, tuple)):
686      if len(result) != 2:
687        raise ValueError(
688            'input_fn should return (features, labels) as a len 2 tuple.')
689      return result[0], result[1], input_hooks
690    return result, None, input_hooks
691
692  def _extract_batch_length(self, preds_evaluated):
693    """Extracts batch length of predictions."""
694    batch_length = None
695    for key, value in six.iteritems(preds_evaluated):
696      batch_length = batch_length or value.shape[0]
697      if value.shape[0] != batch_length:
698        raise ValueError('Batch length of predictions should be same. %s has '
699                         'different batch length then others.' % key)
700    return batch_length
701
702  def _extract_keys(self, predictions, predict_keys):
703    """Extracts `predict_keys` from `predictions`."""
704    if not predict_keys:
705      return predictions
706    if not isinstance(predictions, dict):
707      raise ValueError(
708          'predict_keys argument is not valid in case of non-dict predictions.')
709    existing_keys = predictions.keys()
710    predictions = {
711        key: value
712        for key, value in six.iteritems(predictions) if key in predict_keys
713    }
714    if not predictions:
715      raise ValueError('Expected to run at least one output from %s, '
716                       'provided %s.' % (existing_keys, predict_keys))
717    return predictions
718
719  def _create_global_step(self, graph):
720    """Creates the global step tensor in graph.
721
722    The global step tensor must be an integer type with name 'global_step' and
723    be added to the collection ${tf.GraphKeys.GLOBAL_STEP}.
724
725    Args:
726      graph: The graph in which to create the global step tensor.
727
728    Returns:
729      The global step `Tensor`.
730    """
731    return training.create_global_step(graph)
732
733  def _create_and_assert_global_step(self, graph):
734    """Creates and asserts properties of the global step.
735
736    Args:
737      graph: The graph in which to create the global step tensor.
738
739    Returns:
740      The global step `Tensor`.
741    """
742    step = self._create_global_step(graph)
743    assert step == training.get_global_step()
744    assert step.dtype.is_integer
745    return step
746
747  def _call_input_fn(self, input_fn, mode):
748    """Calls the input function.
749
750    Args:
751      input_fn: The input function.
752      mode: ModeKeys
753
754    Returns:
755      Either features or (features, labels) where features and labels are:
756        features - `Tensor` or dictionary of string feature name to `Tensor`.
757        labels - `Tensor` or dictionary of `Tensor` with labels.
758
759    Raises:
760      ValueError: if input_fn takes invalid arguments.
761    """
762    input_fn_args = util.fn_args(input_fn)
763    kwargs = {}
764    if 'mode' in input_fn_args:
765      kwargs['mode'] = mode
766    if 'params' in input_fn_args:
767      kwargs['params'] = self.params
768    if 'config' in input_fn_args:
769      kwargs['config'] = self.config
770    with ops.device('/cpu:0'):
771      return input_fn(**kwargs)
772
773  def _call_model_fn(self, features, labels, mode, config):
774    """Calls model function.
775
776    Args:
777      features: features dict.
778      labels: labels dict.
779      mode: ModeKeys
780      config: RunConfig
781
782    Returns:
783      An `EstimatorSpec` object.
784
785    Raises:
786      ValueError: if model_fn returns invalid objects.
787    """
788    model_fn_args = util.fn_args(self._model_fn)
789    kwargs = {}
790    if 'labels' in model_fn_args:
791      kwargs['labels'] = labels
792    else:
793      if labels is not None:
794        raise ValueError(
795            'model_fn does not take labels, but input_fn returns labels.')
796    if 'mode' in model_fn_args:
797      kwargs['mode'] = mode
798    if 'params' in model_fn_args:
799      kwargs['params'] = self.params
800    if 'config' in model_fn_args:
801      kwargs['config'] = config
802
803    logging.info('Calling model_fn.')
804    model_fn_results = self._model_fn(features=features, **kwargs)
805    logging.info('Done calling model_fn.')
806
807    if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
808      raise ValueError('model_fn should return an EstimatorSpec.')
809
810    return model_fn_results
811
  def _train_model(self, input_fn, hooks, saving_listeners):
    """Builds a training graph and runs it until the session asks to stop.

    A fresh graph is created under `self._device_fn`. It is seeded from
    `RunConfig.tf_random_seed`, and a global step is created and checked.
    `input_fn` is called in TRAIN mode and `model_fn` is called on its
    output. Then the following are assembled, and `train_op` and `loss` are
    run in a `MonitoredTrainingSession` until `should_stop()`:

      * optional warm-starting;
      * a 'loss' summary (if none exists);
      * the worker hooks;
      * a default Saver;
      * an optional chief-only `CheckpointSaverHook`.

    Args:
      input_fn: Input function providing `(features, labels)` or a `Dataset`.
      hooks: Iterable of `SessionRunHook`s added to the worker hooks. Must not
        be `None` (it is passed directly to `list.extend`).
      saving_listeners: `CheckpointSaverListener`s to attach to the first
        `CheckpointSaverHook`; may be empty.

    Returns:
      The loss value from the last `mon_sess.run` call, or `None` if the
      session stopped before any step ran.

    Raises:
      ValueError: if `saving_listeners` is non-empty but there is no
        `CheckpointSaverHook` (none supplied and checkpoint saving disabled in
        `RunConfig`), or if `model_fn`/`input_fn` validation fails.
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      # NOTE(review): private training_util helper; judging by its name it
      # creates (or reuses) a cached read of the global step -- confirm.
      training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

      if self._warm_start_settings:
        logging.info('Warm-starting with WarmStartSettings: %s' %
                     (self._warm_start_settings,))
        # pylint: disable=protected-access
        warm_starting_util._warm_start(self._warm_start_settings)
        # pylint: enable=protected-access
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      # Worker hooks: input hooks, user hooks, NaN check, periodic loss/step
      # logging (every 100 iterations), then model_fn's training hooks.
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      # Register a default sharded Saver, configured from RunConfig, only if
      # neither the scaffold nor the SAVERS collection already provides one.
      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      # When RunConfig requests periodic checkpoints and no CheckpointSaverHook
      # was supplied (by the caller or by model_fn), add one as a chief-only
      # hook.
      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If multiple, we pick
          # up the first one to add listener.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        # `loss` stays None if the session stops before running any step.
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
904
905  def _evaluate_model(self,
906                      input_fn,
907                      hooks=None,
908                      checkpoint_path=None,
909                      name=''):
910    """Evaluates the model using the training.evaluation library."""
911    # Check that model has been trained (if nothing has been set explicitly).
912    if not checkpoint_path:
913      latest_path = saver.latest_checkpoint(self._model_dir)
914      if not latest_path:
915        raise ValueError('Could not find trained model in model_dir: {}.'.
916                         format(self._model_dir))
917      checkpoint_path = latest_path
918
919    # Setup output directory.
920    eval_dir = os.path.join(self._model_dir, 'eval' if not name else
921                            'eval_' + name)
922
923    with ops.Graph().as_default() as g:
924      random_seed.set_random_seed(self._config.tf_random_seed)
925      global_step_tensor = self._create_and_assert_global_step(g)
926      features, labels, input_hooks = (
927          self._get_features_and_labels_from_input_fn(
928              input_fn, model_fn_lib.ModeKeys.EVAL))
929      estimator_spec = self._call_model_fn(
930          features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
931
932      if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops:
933        raise ValueError(
934            'Metric with name "%s" is not allowed, because Estimator ' % (
935                model_fn_lib.LOSS_METRIC_KEY) +
936            'already defines a default metric with the same name.')
937      estimator_spec.eval_metric_ops[
938          model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss)
939
940      update_op, eval_dict = _extract_metric_update_ops(
941          estimator_spec.eval_metric_ops)
942
943      if ops.GraphKeys.GLOBAL_STEP in eval_dict:
944        raise ValueError(
945            'Metric with name `global_step` is not allowed, because Estimator '
946            'already defines a default metric with the same name.')
947      eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor
948
949      all_hooks = list(input_hooks)
950      all_hooks.extend(hooks)
951      all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
952
953      eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
954          checkpoint_path=checkpoint_path,
955          master=self._config.evaluation_master,
956          scaffold=estimator_spec.scaffold,
957          eval_ops=update_op,
958          final_ops=eval_dict,
959          hooks=all_hooks,
960          config=self._session_config)
961
962      _write_dict_to_summary(
963          output_dir=eval_dir,
964          dictionary=eval_results,
965          current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
966
967    return eval_results
968
969
def _check_checkpoint_available(model_dir):
  """Raises `ValueError` unless `saver.latest_checkpoint` finds a checkpoint.

  Args:
    model_dir: Directory searched for the latest checkpoint.

  Raises:
    ValueError: if no checkpoint is found in `model_dir`.
  """
  if saver.latest_checkpoint(model_dir):
    return
  raise ValueError(
      'Could not find trained model in model_dir: {}.'.format(model_dir))
975
976
977def _check_hooks_type(hooks):
978  """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
979  hooks = list(hooks or [])
980  for h in hooks:
981    if not isinstance(h, training.SessionRunHook):
982      raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
983  return hooks
984
985
986def _check_listeners_type(saving_listeners):
987  """Check listeners type."""
988  listeners = list(saving_listeners or [])
989  for l in listeners:
990    if not isinstance(l, training.CheckpointSaverListener):
991      raise TypeError(
992          'saving_listeners must be a list of CheckpointSaverListener, '
993          'given: {}'.format(l))
994  return listeners
995
996
997def _get_replica_device_setter(config):
998  """Creates a replica device setter if required as a default device_fn.
999
1000  `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
1001  distributed related arguments such as number of ps_replicas based on given
1002  config.
1003
1004  Args:
1005    config: A `RunConfig` instance.
1006
1007  Returns:
1008    A replica device setter, or None.
1009  """
1010  ps_ops = [
1011      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
1012      'MutableHashTableV2', 'MutableHashTableOfTensors',
1013      'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
1014      'MutableDenseHashTableV2', 'VarHandleOp'
1015  ]
1016
1017  if config.task_type:
1018    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
1019  else:
1020    worker_device = '/job:worker'
1021
1022  if config.num_ps_replicas > 0:
1023    return training.replica_device_setter(
1024        ps_tasks=config.num_ps_replicas,
1025        worker_device=worker_device,
1026        merge_devices=True,
1027        ps_ops=ps_ops,
1028        cluster=config.cluster_spec)
1029  else:
1030    return None
1031
1032
def _verify_model_fn_args(model_fn, params):
  """Validates the argument names declared by `model_fn`.

  Args:
    model_fn: The model function given to `Estimator`.
    params: The `params` value given to `Estimator`, or `None`.

  Raises:
    ValueError: if `model_fn` has no `features` argument, if `params` is set
      but `model_fn` has no `params` argument, or if `model_fn` declares an
      argument outside `_VALID_MODEL_FN_ARGS`.
  """
  declared = set(util.fn_args(model_fn))
  if 'features' not in declared:
    raise ValueError('model_fn (%s) must include features argument.' % model_fn)

  accepts_params = 'params' in declared
  if accepts_params and params is None:
    # Tolerated, but likely a mistake, so only warn.
    logging.warning('Estimator\'s model_fn (%s) includes params '
                    'argument, but params are not passed to Estimator.',
                    model_fn)
  elif params is not None and not accepts_params:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' % (model_fn,
                                                                  params))

  unexpected_args = list(declared - _VALID_MODEL_FN_ARGS)
  if unexpected_args:
    raise ValueError('model_fn (%s) has following not expected args: %s' %
                     (model_fn, unexpected_args))
1050
1051
1052def _load_global_step_from_checkpoint_dir(checkpoint_dir):
1053  try:
1054    checkpoint_reader = training.NewCheckpointReader(
1055        training.latest_checkpoint(checkpoint_dir))
1056    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
1057  except:  # pylint: disable=bare-except
1058    return 0
1059
1060
def _extract_metric_update_ops(eval_dict):
  """Separates metric value ops from their update ops.

  Args:
    eval_dict: dict mapping metric name to a sequence whose element 0 is the
      metric value and element 1 its update op.

  Returns:
    A `(update_op, value_ops)` tuple. `update_op` groups every update op, or is
    `None` when `eval_dict` is empty. `value_ops` maps each metric name to its
    value.
  """
  # Process metrics in sorted name order so the graph is identical every time.
  ordered = sorted(six.iteritems(eval_dict))
  value_ops = dict((name, metric_ops[0]) for name, metric_ops in ordered)
  update_ops = [metric_ops[1] for _, metric_ops in ordered]
  update_op = control_flow_ops.group(*update_ops) if update_ops else None
  return update_op, value_ops
1076
1077
def _dict_to_str(dictionary):
  """Renders `dictionary` as `'k1 = v1, k2 = v2, ...'` with keys sorted.

  Args:
    dictionary: The `dict` to render.

  Returns:
    A `str` with one `key = value` entry per item, joined by `', '`.
  """
  sorted_items = sorted(six.iteritems(dictionary))
  return ', '.join(['%s = %s' % item for item in sorted_items])
1089
1090
def _write_dict_to_summary(output_dir,
                           dictionary,
                           current_global_step):
  """Writes a `dict` into summary file in given output directory.

  Entries are handled by value type:

    * Python or numpy floating/integer scalars become `simple_value`
      summaries tagged with the dict key.
    * `bytes` values are parsed as serialized `Summary` protos, and each
      contained value is re-tagged with the dict key.
    * `None` values and the global step entry are skipped silently.
    * Anything else is skipped with a warning.

  The resulting proto is added at `current_global_step` through the cached
  `FileWriter` for `output_dir`, which is then flushed.

  Args:
    output_dir: `str`, directory to write the summary file in.
    dictionary: the `dict` to be written to summary file.
    current_global_step: `int`, the current global step.
  """
  logging.info('Saving dict for global step %d: %s', current_global_step,
               _dict_to_str(dictionary))
  summary_writer = writer_cache.FileWriterCache.get(output_dir)
  summary_proto = summary_pb2.Summary()
  for key, value in six.iteritems(dictionary):
    if value is None:
      continue
    if key == ops.GraphKeys.GLOBAL_STEP:
      continue
    # Accept every numpy float/integer width (e.g. float16, int8, uint32), not
    # only float32/int32/int64, which were previously skipped silently.
    if isinstance(value, (float, np.floating)):
      summary_proto.value.add(tag=key, simple_value=float(value))
    elif isinstance(value, six.integer_types + (np.integer,)):
      summary_proto.value.add(tag=key, simple_value=int(value))
    elif isinstance(value, six.binary_type):
      try:
        parsed = summary_pb2.Summary.FromString(value)
      except message.DecodeError:
        logging.warning(
            'Skipping summary for %s, cannot parse string to Summary.', key)
        continue
      for summary_value in parsed.value:
        summary_value.tag = key
      summary_proto.value.extend(parsed.value)
    else:
      logging.warning(
          'Skipping summary for %s, must be a float, np.float32, np.int64, '
          'np.int32 or int or a serialized string of Summary.', key)
  summary_writer.add_summary(summary_proto, current_global_step)
  summary_writer.flush()
1133
1134
def _has_dataset_or_queue_runner(maybe_tensor):
  """Returns True if TF dataset or QueueRunner has been used.

  Args:
    maybe_tensor: A `Tensor` or a nested structure containing tensors (e.g. the
      features returned by `input_fn`). Non-`Tensor` leaves are ignored.

  Returns:
    `True` if any of the given tensors is produced directly by an
    `IteratorGetNext` op, or if the default graph has a QueueRunner; else
    `False`.
  """
  # Check TF dataset first. Here, we use a simple algorithm to check the top
  # level Tensors only, which should be sufficient for most users.
  tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)]
  if any(t.op.type == 'IteratorGetNext' for t in tensors):
    return True

  # Now, check queue. Coerce to bool: the collection itself is a list.
  return bool(
      ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS))
1145
1146
class _DatasetInitializerHook(training.SessionRunHook):
  """Runs an iterator's initializer each time a session is created.

  Args:
    iterator: An iterator exposing an `initializer` op, such as those built
      with `make_initializable_iterator()` in this file.
  """

  def __init__(self, iterator):
    self._iterator = iterator

  def begin(self):
    # Look the op up in `begin`; per the SessionRunHook contract, the graph
    # may still be modified at this point.
    self._initializer = self._iterator.initializer

  def after_create_session(self, session, coord):
    del coord  # The coordinator is not needed to run the initializer.
    session.run(self._initializer)
1158