# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base Estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os
import tempfile

import numpy as np
import six

from google.protobuf import message
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.export.export import build_all_signature_defs
from tensorflow.python.estimator.export.export import get_temp_export_dir
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


_VALID_MODEL_FN_ARGS = set(
    ['features', 'labels', 'mode', 'params', 'self', 'config'])


@tf_export('estimator.Estimator')
class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The `Estimator` object wraps a model which is specified by a `model_fn`,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or prediction.

  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or
  a subdirectory thereof. If `model_dir` is not set, a temporary directory is
  used.
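
  A rough usage sketch (`my_model_fn` and the input functions here stand for
  hypothetical user-defined functions, not symbols from this module):

  ```python
  estimator = tf.estimator.Estimator(model_fn=my_model_fn)
  estimator.train(input_fn=my_train_input_fn, steps=1000)
  metrics = estimator.evaluate(input_fn=my_eval_input_fn)
  predictions = estimator.predict(input_fn=my_predict_input_fn)
  ```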

  The `config` argument can be passed a `RunConfig` object containing
  information about the execution environment. It is passed on to the
  `model_fn`, if the `model_fn` has a parameter named "config" (and input
  functions in the same manner). If the `config` parameter is not passed, it
  is instantiated by the `Estimator`. Not passing config means that defaults
  useful for local execution are used. `Estimator` makes config available to
  the model (for instance, to allow specialization based on the number of
  workers available), and also uses some of its fields to control internals,
  especially regarding checkpointing.

  The `params` argument contains hyperparameters. It is passed to the
  `model_fn`, if the `model_fn` has a parameter named "params", and to the
  input functions in the same manner. `Estimator` only passes params along;
  it does not inspect it. The structure of `params` is therefore entirely up
  to the developer.

  None of `Estimator`'s methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use `model_fn` to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Estimators are not compatible with eager execution.
  @end_compatibility
  """

  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See @{$estimators} for more information. To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    @{tf.estimator.WarmStartSettings$WarmStartSettings}.

    Args:
      model_fn: Model function. Follows the signature below (see also the
        sketch at the end of this docstring):

        * Args:

          * `features`: This is the first item returned from the `input_fn`
             passed to `train`, `evaluate`, and `predict`. This should be a
             single `Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
             passed to `train`, `evaluate`, and `predict`. This should be a
             single `Tensor` or `dict` of same (for multi-head models). If
             mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
             the `model_fn`'s signature does not accept `mode`, the
             `model_fn` must still be able to handle `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation, or
             prediction. See `ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
             is passed to Estimator in the `params` parameter. This allows
             Estimators to be configured through hyperparameter tuning.
          * `config`: Optional configuration object. Will receive what is
             passed to Estimator in the `config` parameter, or the default
             `config`. Allows updating things in your model_fn based on
             configuration such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `EstimatorSpec`

      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If a `PathLike`
        object, the path will be resolved. If `None`, the model_dir in
        `config` will be used if set. If both are set, they must be the same.
        If both are `None`, a temporary directory will be used.
      config: Configuration object.
      params: `dict` of hyperparameters that will be passed into `model_fn`.
        Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint to warm-start
        from, or a `tf.estimator.WarmStartSettings` object to fully configure
        warm-starting. If the string filepath is provided instead of a
        `WarmStartSettings`, then all variables are warm-started, and it is
        assumed that vocabularies and Tensor names are unchanged.

    Raises:
      RuntimeError: If eager execution is enabled.
      ValueError: If parameters of `model_fn` don't match `params`.
      ValueError: If this is called via a subclass and if that class
        overrides a member of `Estimator`.
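
    As a rough illustration of the expected `model_fn` contract, a minimal
    sketch (`my_model_fn`, the feature key `'x'`, and the learning-rate
    default are hypothetical):

    ```python
    def my_model_fn(features, labels, mode, params):
      predictions = tf.layers.dense(features['x'], 1)
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      loss = tf.losses.mean_squared_error(labels, predictions)
      train_op = tf.train.GradientDescentOptimizer(
          params.get('learning_rate', 0.1)).minimize(
              loss, global_step=tf.train.get_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    ```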
    """
    if context.in_eager_mode():
      raise RuntimeError(
          'Estimators are not supported when eager execution is enabled.')

    Estimator._assert_members_are_not_overridden(self)

    if config is None:
      self._config = run_config.RunConfig()
      logging.info('Using default config.')
    else:
      if not isinstance(config, run_config.RunConfig):
        raise ValueError(
            'config must be an instance of RunConfig, but provided %s.' %
            config)
      self._config = config

    # Model directory.
    model_dir = compat_internal.path_to_str(model_dir)
    if (model_dir is not None) and (self._config.model_dir is not None):
      if model_dir != self._config.model_dir:
        # TODO(alanyee): remove this suppression after it is no longer needed
        # pylint: disable=g-doc-exception
        raise ValueError(
            "model_dir is set in both the constructor and RunConfig, but "
            "with different values. In constructor: '{}', in RunConfig: "
            "'{}' ".format(model_dir, self._config.model_dir))
        # pylint: enable=g-doc-exception

    self._model_dir = model_dir or self._config.model_dir
    if self._model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    if self._config.model_dir is None:
      self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    if self._config.session_config is None:
      self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      self._session_config = self._config.session_config

    self._device_fn = _get_replica_device_setter(self._config)

    if model_fn is None:
      raise ValueError('model_fn must be provided to Estimator.')
    _verify_model_fn_args(model_fn, params)
    self._model_fn = model_fn
    self._params = copy.deepcopy(params or {})

    # pylint: disable=protected-access
    self._warm_start_settings = (
        warm_starting_util._get_default_warm_start_settings(warm_start_from))
    # pylint: enable=protected-access

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def config(self):
    return copy.deepcopy(self._config)

  @property
  def params(self):
    return copy.deepcopy(self._params)

  @property
  def model_fn(self):
    """Returns the model_fn which is bound to self.params.

    Returns:
      The model_fn with the following signature:
        `def model_fn(features, labels, mode, config)`
    """

    def public_model_fn(features, labels, mode, config):
      return self._call_model_fn(features, labels, mode, config)

    return public_model_fn

  # TODO(ispir): support a list of names
  def get_variable_value(self, name):
    """Returns the value of the variable given by name.

    Args:
      name: string, name of the tensor.

    Returns:
      Numpy array - value of the tensor.

    Raises:
      ValueError: If the Estimator has not produced a checkpoint yet.
    """
    _check_checkpoint_available(self.model_dir)
    return training.load_variable(self.model_dir, name)

  def get_variable_names(self):
    """Returns the list of all variable names in this model.

    Returns:
      List of names.

    Raises:
      ValueError: If the Estimator has not produced a checkpoint yet.
    """
    _check_checkpoint_available(self.model_dir)
    return [name for name, _ in training.list_variables(self.model_dir)]

  def latest_checkpoint(self):
    """Finds the filename of the latest saved checkpoint file in `model_dir`.

    Returns:
      The full path to the latest checkpoint or `None` if no checkpoint was
      found.
    """
    return saver.latest_checkpoint(self.model_dir)

  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Trains a model given training data input_fn.

    Args:
      input_fn: A function that provides input data for training as
        minibatches. See
        @{$get_started/premade_estimators#create_input_functions} for more
        information. The function should construct and return one of
        the following:

        * A `tf.data.Dataset` object: Outputs of the `Dataset` object must be
          a tuple (features, labels) with the same constraints as below.
        * A tuple (features, labels): Where features is a `Tensor` or a
          dictionary of string feature name to `Tensor` and labels is a
          `Tensor` or a dictionary of string label name to `Tensor`. Both
          features and labels are consumed by `model_fn`. They should satisfy
          the expectation of `model_fn` from inputs.

      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or train until `input_fn` generates the `OutOfRange` error or
        `StopIteration` exception. `steps` works incrementally: two calls to
        `train(steps=10)` train for 20 steps in total. If `OutOfRange` or
        `StopIteration` occurs in the middle, training stops before 20 steps.
        If you don't want incremental behavior, set `max_steps` instead. If
        set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train the model. If
        `None`, train forever or train until `input_fn` generates the
        `OutOfRange` error or `StopIteration` exception. If set, `steps` must
        be `None`. If `OutOfRange` or `StopIteration` occurs in the middle,
        training stops before `max_steps` steps. Two calls to
        `train(steps=100)` means 200 training iterations. On the other hand,
        two calls to `train(max_steps=100)` means that the second call will
        not do any iteration since the first call did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps` is <= 0.
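
    The incremental behavior of `steps`, as a sketch (`my_input_fn` is a
    hypothetical user-defined input function):

    ```python
    estimator.train(input_fn=my_input_fn, steps=10)
    estimator.train(input_fn=my_input_fn, steps=10)  # 20 steps in total
    ```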
    """
    if (steps is not None) and (max_steps is not None):
      raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
      raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
      raise ValueError(
          'Must specify max_steps > 0, given: {}'.format(max_steps))

    if max_steps is not None:
      start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already been '
                     'reached.')
        return self

    hooks = _check_hooks_type(hooks)
    hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

    saving_listeners = _check_listeners_type(saving_listeners)
    loss = self._train_model(input_fn, hooks, saving_listeners)
    logging.info('Loss for final step: %s.', loss)
    return self

  def _convert_train_steps_to_hooks(self, steps, max_steps):
    if steps is not None or max_steps is not None:
      return [training.StopAtStepHook(steps, max_steps)]
    else:
      return []

  def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
               name=None):
    """Evaluates the model given evaluation data input_fn.

    For each step, calls `input_fn`, which returns one batch of data.
    Evaluates until:
    - `steps` batches are processed, or
    - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
      `StopIteration`).

    Args:
      input_fn: A function that constructs the input data for evaluation.
        See @{$get_started/premade_estimators#create_input_functions} for
        more information. The function should construct and return one of
        the following:

        * A `tf.data.Dataset` object: Outputs of the `Dataset` object must be
          a tuple (features, labels) with the same constraints as below.
        * A tuple (features, labels): Where features is a `Tensor` or a
          dictionary of string feature name to `Tensor` and labels is a
          `Tensor` or a dictionary of string label name to `Tensor`. Both
          features and labels are consumed by `model_fn`. They should satisfy
          the expectation of `model_fn` from inputs.

      steps: Number of steps for which to evaluate the model. If `None`,
        evaluates until `input_fn` raises an end-of-input exception.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the evaluation call.
      checkpoint_path: Path of a specific checkpoint to evaluate. If `None`,
        the latest checkpoint in `model_dir` is used.
      name: Name of the evaluation if user needs to run multiple evaluations
        on different data sets, such as on training data vs test data.
        Metrics for different evaluations are saved in separate folders, and
        appear separately in tensorboard.

    Returns:
      A dict containing the evaluation metrics specified in `model_fn` keyed
      by name, as well as an entry `global_step` which contains the value of
      the global step for which this evaluation was performed.

    Raises:
      ValueError: If `steps <= 0`.
      ValueError: If no model has been trained, namely `model_dir`, or the
        given `checkpoint_path` is empty.
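
    For example, to run named evaluations on separate data sets (the input
    functions here are hypothetical):

    ```python
    train_metrics = estimator.evaluate(input_fn=my_train_input_fn,
                                       name='training')
    test_metrics = estimator.evaluate(input_fn=my_test_input_fn,
                                      name='test')
    ```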
    """
    hooks = _check_hooks_type(hooks)
    hooks.extend(self._convert_eval_steps_to_hooks(steps))

    return self._evaluate_model(
        input_fn=input_fn,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        name=name)

  def _convert_eval_steps_to_hooks(self, steps):
    if steps is None:
      return []

    if steps <= 0:
      raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access

  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    """Yields predictions for given features.

    Args:
      input_fn: A function that constructs the features. Prediction continues
        until `input_fn` raises an end-of-input exception (`OutOfRangeError`
        or `StopIteration`).
        See @{$get_started/premade_estimators#create_input_functions} for
        more information. The function should construct and return one of
        the following:

        * A `tf.data.Dataset` object: Outputs of the `Dataset` object must
          have the same constraints as below.
        * features: A `Tensor` or a dictionary of string feature name to
          `Tensor`. features are consumed by `model_fn`. They should satisfy
          the expectation of `model_fn` from inputs.
        * A tuple, in which case the first item is extracted as features.

      predict_keys: list of `str`, names of the keys to predict. It is used
        if the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is
        used, then the rest of the predictions will be filtered from the
        dictionary. If `None`, returns all.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the prediction call.
      checkpoint_path: Path of a specific checkpoint to predict. If `None`,
        the latest checkpoint in `model_dir` is used.
      yield_single_examples: If False, yield the whole batch as returned by
        the model_fn instead of decomposing the batch into individual
        elements. This is useful if model_fn returns some tensor whose first
        dimension is not equal to the batch size.

    Yields:
      Evaluated values of `predictions` tensors.

    Raises:
      ValueError: Could not find a trained model in model_dir.
      ValueError: If batch length of predictions is not the same and
        yield_single_examples is True.
      ValueError: If there is a conflict between `predict_keys` and
        `predictions`. For example if `predict_keys` is not `None` but
        `EstimatorSpec.predictions` is not a `dict`.
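
    A typical consumption sketch (`my_input_fn` and the `'class_ids'` key are
    hypothetical; the available keys depend on `EstimatorSpec.predictions`):

    ```python
    for prediction in estimator.predict(input_fn=my_input_fn):
      print(prediction['class_ids'])
    ```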
    """
    hooks = _check_hooks_type(hooks)
    # Check that the model has been trained.
    if not checkpoint_path:
      checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise ValueError('Could not find trained model in model_dir: {}.'.format(
          self._model_dir))

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      self._create_and_assert_global_step(g)
      features, input_hooks = self._get_features_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.PREDICT)
      estimator_spec = self._call_model_fn(
          features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
      predictions = self._extract_keys(estimator_spec.predictions,
                                       predict_keys)
      all_hooks = list(input_hooks)
      all_hooks.extend(hooks)
      all_hooks.extend(list(estimator_spec.prediction_hooks or []))
      with training.MonitoredSession(
          session_creator=training.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path,
              master=self._config.master,
              scaffold=estimator_spec.scaffold,
              config=self._session_config),
          hooks=all_hooks) as mon_sess:
        while not mon_sess.should_stop():
          preds_evaluated = mon_sess.run(predictions)
          if not yield_single_examples:
            yield preds_evaluated
          elif not isinstance(predictions, dict):
            for pred in preds_evaluated:
              yield pred
          else:
            for i in range(self._extract_batch_length(preds_evaluated)):
              yield {
                  key: value[i]
                  for key, value in six.iteritems(preds_evaluated)
              }

  def _assert_members_are_not_overridden(self):
    """Asserts members of `Estimator` are not overridden."""
    allowed_overrides = set([
        '_call_input_fn', '_create_global_step',
        '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
        '_tf_api_names'
    ])
    estimator_members = set([m for m in Estimator.__dict__.keys()
                             if not m.startswith('__')])
    subclass_members = set(self.__class__.__dict__.keys())
    common_members = estimator_members & subclass_members - allowed_overrides
    overridden_members = [
        m for m in common_members
        if Estimator.__dict__[m] != self.__class__.__dict__[m]]
    if overridden_members:
      raise ValueError(
          'Subclasses of Estimator cannot override members of Estimator. '
          '{} does override {}'.format(self.__class__, overridden_members))

  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Exports inference graph as a SavedModel into the given dir.

    For a detailed guide, see
    @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session. Finally it creates
    a timestamped export directory below the given export_dir_base, and
    writes a `SavedModel` into it containing a single `MetaGraphDef` saved
    from this session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named
    using the same keys. One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided
    by the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument. This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory. The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
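
    A rough export sketch (the receiver function and its feature spec are
    hypothetical and depend on the model's features):

    ```python
    def my_serving_input_receiver_fn():
      inputs = {'x': tf.placeholder(tf.float32, shape=[None, 4])}
      return tf.estimator.export.ServingInputReceiver(inputs, inputs)

    export_dir = estimator.export_savedmodel(
        export_dir_base='/tmp/exports',
        serving_input_receiver_fn=my_serving_input_receiver_fn)
    ```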

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra
        directory within the exported SavedModel, or `None` if no extra
        assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export. If `None` (the
        default), the most recent checkpoint found within the model directory
        is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no
        export_outputs are provided, or no checkpoint can be found.
    """
    # pylint: enable=line-too-long
    if serving_input_receiver_fn is None:
      raise ValueError('serving_input_receiver_fn must be defined.')

    with ops.Graph().as_default() as g:
      self._create_and_assert_global_step(g)
      random_seed.set_random_seed(self._config.tf_random_seed)
      serving_input_receiver = serving_input_receiver_fn()

      # Call the model_fn and collect the export_outputs.
      estimator_spec = self._call_model_fn(
          features=serving_input_receiver.features,
          labels=None,
          mode=model_fn_lib.ModeKeys.PREDICT,
          config=self.config)

      # Build the SignatureDefs from receivers and all outputs.
      signature_def_map = build_all_signature_defs(
          serving_input_receiver.receiver_tensors,
          estimator_spec.export_outputs,
          serving_input_receiver.receiver_tensors_alternatives)

      if not checkpoint_path:
        # Locate the latest checkpoint.
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s."
                         % self._model_dir)

      export_dir = get_timestamped_export_dir(export_dir_base)
      temp_export_dir = get_temp_export_dir(export_dir)

      # TODO(soergel): Consider whether MonitoredSession makes sense here
      with tf_session.Session(config=self._session_config) as session:

        saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        # pylint: disable=protected-access
        local_init_op = (
            estimator_spec.scaffold.local_init_op or
            monitored_session.Scaffold._default_local_init_op())
        # pylint: enable=protected-access

        # Perform the export.
        builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=local_init_op,
            strip_default_attrs=strip_default_attrs)
        builder.save(as_text)

      # Add the extra assets.
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

      gfile.Rename(temp_export_dir, export_dir)
      return export_dir

  def _get_features_from_input_fn(self, input_fn, mode):
    """Extracts the `features` from return values of `input_fn`."""
    result = self._call_input_fn(input_fn, mode)
    input_hooks = []
    if isinstance(result, dataset_ops.Dataset):
      iterator = result.make_initializable_iterator()
      input_hooks.append(_DatasetInitializerHook(iterator))
      result = iterator.get_next()
    if isinstance(result, (list, tuple)):
      # Unconditionally drop the label (the second element of result).
      result = result[0]

    if not _has_dataset_or_queue_runner(result):
      logging.warning('Input graph does not use tf.data.Dataset or contain a '
                      'QueueRunner. That means predict yields forever. '
                      'This is probably a mistake.')
    return result, input_hooks

  def _get_features_and_labels_from_input_fn(self, input_fn, mode):
    """Extracts the `features` and labels from return values of `input_fn`."""
    result = self._call_input_fn(input_fn, mode)
    input_hooks = []
    if isinstance(result, dataset_ops.Dataset):
      iterator = result.make_initializable_iterator()
      input_hooks.append(_DatasetInitializerHook(iterator))
      result = iterator.get_next()
    if isinstance(result, (list, tuple)):
      if len(result) != 2:
        raise ValueError(
            'input_fn should return (features, labels) as a len 2 tuple.')
      return result[0], result[1], input_hooks
    return result, None, input_hooks

  def _extract_batch_length(self, preds_evaluated):
    """Extracts batch length of predictions."""
    batch_length = None
    for key, value in six.iteritems(preds_evaluated):
      batch_length = batch_length or value.shape[0]
      if value.shape[0] != batch_length:
        raise ValueError('Batch length of predictions should be same. %s has '
                         'different batch length than others.'
                         % key)
    return batch_length

  def _extract_keys(self, predictions, predict_keys):
    """Extracts `predict_keys` from `predictions`."""
    if not predict_keys:
      return predictions
    if not isinstance(predictions, dict):
      raise ValueError(
          'predict_keys argument is not valid in case of non-dict '
          'predictions.')
    existing_keys = predictions.keys()
    predictions = {
        key: value
        for key, value in six.iteritems(predictions) if key in predict_keys
    }
    if not predictions:
      raise ValueError('Expected to run at least one output from %s, '
                       'provided %s.' % (existing_keys, predict_keys))
    return predictions

  def _create_global_step(self, graph):
    """Creates the global step tensor in graph.

    The global step tensor must be an integer type with name 'global_step'
    and be added to the collection @{tf.GraphKeys.GLOBAL_STEP}.

    Args:
      graph: The graph in which to create the global step tensor.

    Returns:
      The global step `Tensor`.
    """
    return training.create_global_step(graph)

  def _create_and_assert_global_step(self, graph):
    """Creates and asserts properties of the global step.

    Args:
      graph: The graph in which to create the global step tensor.

    Returns:
      The global step `Tensor`.
    """
    step = self._create_global_step(graph)
    assert step == training.get_global_step()
    assert step.dtype.is_integer
    return step

  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments.
    """
    input_fn_args = util.fn_args(input_fn)
    kwargs = {}
    if 'mode' in input_fn_args:
      kwargs['mode'] = mode
    if 'params' in input_fn_args:
      kwargs['params'] = self.params
    if 'config' in input_fn_args:
      kwargs['config'] = self.config
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)

  def _call_model_fn(self, features, labels, mode, config):
    """Calls the model function.

    Args:
      features: features dict.
      labels: labels dict.
      mode: ModeKeys
      config: RunConfig

    Returns:
      An `EstimatorSpec` object.

    Raises:
      ValueError: if model_fn returns invalid objects.
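
    For example, if the `model_fn` has the (hypothetical) signature
    `def my_model_fn(features, mode)`, only `mode` is forwarded as a keyword
    argument, and calling with a non-`None` `labels` raises `ValueError`.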
    """
    model_fn_args = util.fn_args(self._model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    else:
      if labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = mode
    if 'params' in model_fn_args:
      kwargs['params'] = self.params
    if 'config' in model_fn_args:
      kwargs['config'] = config

    logging.info('Calling model_fn.')
    model_fn_results = self._model_fn(features=features, **kwargs)
    logging.info('Done calling model_fn.')

    if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
      raise ValueError('model_fn should return an EstimatorSpec.')

    return model_fn_results

  def _train_model(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

      if self._warm_start_settings:
        logging.info('Warm-starting with WarmStartSettings: %s' %
                     (self._warm_start_settings,))
        # pylint: disable=protected-access
        warm_starting_util._warm_start(self._warm_start_settings)
        # pylint: enable=protected-access
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we
      # will make another one with the name 'loss' to ensure it shows up in
      # the right graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If there are
          # multiple, we pick up the first one to add the listeners to.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op,
                                  estimator_spec.loss])
      return loss

  def _evaluate_model(self,
                      input_fn,
                      hooks=None,
                      checkpoint_path=None,
                      name=''):
    """Evaluates the model using the training.evaluation library."""
    # Check that the model has been trained (if nothing has been set
    # explicitly).
    if not checkpoint_path:
      latest_path = saver.latest_checkpoint(self._model_dir)
      if not latest_path:
        raise ValueError('Could not find trained model in model_dir: {}.'.
                         format(self._model_dir))
      checkpoint_path = latest_path

    # Setup the output directory.
    eval_dir = os.path.join(self._model_dir, 'eval' if not name else
                            'eval_' + name)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.EVAL))
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.EVAL, self.config)

      if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops:
        raise ValueError(
            'Metric with name "%s" is not allowed, because Estimator ' % (
                model_fn_lib.LOSS_METRIC_KEY) +
            'already defines a default metric with the same name.')
      estimator_spec.eval_metric_ops[
          model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(
              estimator_spec.loss)

      update_op, eval_dict = _extract_metric_update_ops(
          estimator_spec.eval_metric_ops)

      if ops.GraphKeys.GLOBAL_STEP in eval_dict:
        raise ValueError(
            'Metric with name `global_step` is not allowed, because '
            'Estimator already defines a default metric with the same name.')
      eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor

      all_hooks = list(input_hooks)
      all_hooks.extend(hooks)
      all_hooks.extend(list(estimator_spec.evaluation_hooks or []))

      eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
          checkpoint_path=checkpoint_path,
          master=self._config.evaluation_master,
          scaffold=estimator_spec.scaffold,
          eval_ops=update_op,
          final_ops=eval_dict,
          hooks=all_hooks,
          config=self._session_config)

      _write_dict_to_summary(
          output_dir=eval_dir,
          dictionary=eval_results,
          current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])

    return eval_results


def _check_checkpoint_available(model_dir):
  latest_path = saver.latest_checkpoint(model_dir)
  if not latest_path:
    raise ValueError(
        'Could not find trained model in model_dir: {}.'.format(model_dir))


def _check_hooks_type(hooks):
  """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
  hooks = list(hooks or [])
  for h in hooks:
    if not isinstance(h, training.SessionRunHook):
      raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
  return hooks


def _check_listeners_type(saving_listeners):
  """Checks listener types."""
  listeners = list(saving_listeners or [])
  for l in listeners:
    if not isinstance(l, training.CheckpointSaverListener):
      raise TypeError(
          'saving_listeners must be a list of CheckpointSaverListener, '
          'given: {}'.format(l))
  return listeners


def _get_replica_device_setter(config):
  """Creates a replica device setter if required as a default device_fn.

  `Estimator` uses `ReplicaDeviceSetter` as a default device placer. It sets
  the distribution-related arguments, such as the number of ps_replicas,
  based on the given config.

  Args:
    config: A `RunConfig` instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableV2', 'MutableHashTableOfTensors',
      'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
      'MutableDenseHashTableV2', 'VarHandleOp'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return training.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=ps_ops,
        cluster=config.cluster_spec)
  else:
    return None


def _verify_model_fn_args(model_fn, params):
  """Verifies model fn arguments."""
  args = set(util.fn_args(model_fn))
  if 'features' not in args:
    raise ValueError('model_fn (%s) must include features argument.' %
                     model_fn)
  if params is not None and 'params' not in args:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' % (model_fn,
                                                                  params))
  if params is None and 'params' in args:
    logging.warning('Estimator\'s model_fn (%s) includes params '
                    'argument, but params are not passed to Estimator.',
                    model_fn)
  non_valid_args = list(args - _VALID_MODEL_FN_ARGS)
  if non_valid_args:
    raise ValueError('model_fn (%s) has the following unexpected args: %s' %
                     (model_fn, non_valid_args))


def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0


def _extract_metric_update_ops(eval_dict):
  """Separates update operations from metric value operations."""
  update_ops = []
  value_ops = {}
  # Sort metrics lexicographically so the graph is identical every time.
  for name, metric_ops in sorted(six.iteritems(eval_dict)):
    value_ops[name] = metric_ops[0]
    update_ops.append(metric_ops[1])

  if update_ops:
    update_op = control_flow_ops.group(*update_ops)
  else:
    update_op = None

  return update_op, value_ops


def _dict_to_str(dictionary):
  """Gets a `str` representation of a `dict`.

  Args:
    dictionary: The `dict` to be represented as `str`.

  Returns:
    A `str` representing the `dictionary`.
  """
  return ', '.join('%s = %s' % (k, v)
                   for k, v in sorted(six.iteritems(dictionary)))


def _write_dict_to_summary(output_dir,
                           dictionary,
                           current_global_step):
  """Writes a `dict` into a summary file in the given output directory.

  Args:
    output_dir: `str`, directory to write the summary file in.
    dictionary: the `dict` to be written to the summary file.
    current_global_step: `int`, the current global step.
  """
  logging.info('Saving dict for global step %d: %s', current_global_step,
               _dict_to_str(dictionary))
  summary_writer = writer_cache.FileWriterCache.get(output_dir)
  summary_proto = summary_pb2.Summary()
  for key in dictionary:
    if dictionary[key] is None:
      continue
    if key == 'global_step':
      continue
    if (isinstance(dictionary[key], np.float32) or
        isinstance(dictionary[key], float)):
      summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))
    elif (isinstance(dictionary[key], np.int64) or
          isinstance(dictionary[key], np.int32) or
          isinstance(dictionary[key], int)):
      summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
    elif isinstance(dictionary[key], six.binary_type):
      try:
        summ = summary_pb2.Summary.FromString(dictionary[key])
        for i, _ in enumerate(summ.value):
          summ.value[i].tag = key
        summary_proto.value.extend(summ.value)
      except message.DecodeError:
        logging.warn('Skipping summary for %s, cannot parse string to '
                     'Summary.', key)
        continue
    else:
      logging.warn(
          'Skipping summary for %s, must be a float, np.float32, np.int64, '
          'np.int32 or int or a serialized string of Summary.', key)
  summary_writer.add_summary(summary_proto, current_global_step)
  summary_writer.flush()


def _has_dataset_or_queue_runner(maybe_tensor):
  """Returns True if a TF dataset or QueueRunner has been used."""
  # Check TF dataset first. Here, we use a simple algorithm to check the top
  # level Tensors only, which should be sufficient for most users.
  tensors = [x for x in nest.flatten(maybe_tensor)
             if isinstance(x, ops.Tensor)]
  if any([t.op.type == 'IteratorGetNext' for t in tensors]):
    return True

  # Now, check the queue runners.
  return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)


class _DatasetInitializerHook(training.SessionRunHook):
  """Hook that runs a `Dataset` iterator's initializer after session creation."""

  def __init__(self, iterator):
    self._iterator = iterator

  def begin(self):
    self._initializer = self._iterator.initializer

  def after_create_session(self, session, coord):
    del coord
    session.run(self._initializer)