134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# you may not use this file except in compliance with the License.
534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# You may obtain a copy of the License at
634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
1034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
1134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# See the License for the specific language governing permissions and
1334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# limitations under the License.
1434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# ==============================================================================
1534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower"""Abstractions for the head(s) of a model."""
1634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
1734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import absolute_import
1834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import division
1934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import print_function
2034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport abc
22e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport collections
2334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport six
2534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator import model_fn
273942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerfrom tensorflow.python.estimator import util
2834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import metric_keys
2934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import prediction_keys
3034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.export import export_output
31d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispirfrom tensorflow.python.feature_column import feature_column as feature_column_lib
3234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import dtypes
3334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import ops
3434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import sparse_tensor
3534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops
3634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import check_ops
3767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlowerfrom tensorflow.python.ops import control_flow_ops
38ce21f5e372360451528c9716a743b65308421423Mustafa Ispirfrom tensorflow.python.ops import lookup_ops
3934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops
4034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import metrics as metrics_lib
4134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import nn
4234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import string_ops
4334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import weights_broadcast_ops
4434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops.losses import losses
4502ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispirfrom tensorflow.python.saved_model import signature_constants
46169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispirfrom tensorflow.python.summary import summary
4702ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir
# Default signature key; its value comes from TF Serving's
# `signature_constants`.
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

# Unlike the default key above, the three keys below have no special meaning
# to TF Serving; they are merely a naming convention local to this module.
_CLASSIFY_SERVING_KEY = 'classification'
_REGRESS_SERVING_KEY = 'regression'
_PREDICT_SERVING_KEY = 'predict'
554db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel
5634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
# A LossSpec contains
# * a scalar `Tensor` representing reduced weighted training loss
# * a `Tensor` representing the unreduced unweighted loss (not reduced, so in
#   general not a scalar)
# * a `Tensor` representing the example weights (in general not a scalar)
# * possibly processed labels (e.g. vocabulary lookup, shape manipulation, etc)
LossSpec = collections.namedtuple(
    'LossSpec', ['training_loss', 'unreduced_loss', 'weights',
                 'processed_labels'])
658d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
668d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
670fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlowerdef _summary_key(head_name, val):
680fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower  return '%s/%s' % (val, head_name) if head_name else val
690fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower
700fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower
# `six.add_metaclass` is used instead of a `__metaclass__` class attribute:
# Python 3 ignores that attribute, which would leave the abstract members
# below unenforced there.
@six.add_metaclass(abc.ABCMeta)
class _Head(object):
  """Interface for the head/top of a model.

  Given logits (or output of a hidden layer), a Head knows how to compute
  predictions, loss, train_op, metrics and export outputs. It is meant to:

  1. Simplify writing model_fn and make model_fn more configurable.
  2. Support a wide range of machine learning models. Since most heads can
     work with logits, they can support DNN, RNN, Wide, Wide&Deep,
     Global objectives, Gradient boosted trees and many other types
     of machine learning models.

  Common usage:
  Here is a simplified model_fn to build a DNN regression model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally your callers can pass head to model_fn as a param.
      head = tf.contrib.learn.regression_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_estimator_spec(
          features=features,
          labels=labels,
          mode=mode,
          logits=logits,
          train_op_fn=_train_op_fn)
    ```

  There are cases where computing and applying gradients cannot be
  meaningfully captured with the train_op_fn we support (for example, with a
  sync optimizer). In such cases, you can take on that responsibility
  yourself. Here is a common use case:
    ```python
    estimator_spec = head.create_estimator_spec(
        features=features,
        labels=labels,
        mode=mode,
        logits=logits,
        train_op_fn=tf.contrib.learn.no_op_train_fn)
    if mode == model_fn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=estimator_spec.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      ... update train_op and hooks in EstimatorSpec and return
    ```
  """

  @abc.abstractproperty
  def name(self):
    """The name of this head.

    Returns:
      A string.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_loss(self, features, mode, logits, labels):
    """Returns a loss Tensor from provided logits.

    This function is designed to be used by framework developers. Almost all
    users should use create_estimator_spec(), which calls this internally.
    `mode` and `features` are most likely not used, but some Head
    implementations may require them.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used for loss construction.
      labels: Labels `Tensor`, or `dict` of same.

    Returns:
      A LossSpec that contains
      * the scalar `Tensor` representing reduced weighted training loss
      * the `Tensor` representing the unreduced unweighted loss
      * the `Tensor` representing the example weights
      * possibly processed labels (e.g. vocabulary lookup, shape manipulation,
        etc.)

      To be extendable in the future.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None,
      regularization_losses=None):
    """Returns `EstimatorSpec` that a model_fn can return.

    Please note that all args must be passed by name.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used by the head.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
        to optimize the model with the loss. This is used in TRAIN mode and
        must not be None. None is allowed in other modes. If you want to
        optimize loss yourself you can pass `no_op_train_fn` and then use
        EstimatorSpec.loss to compute and apply gradients.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses.

    Returns:
      `EstimatorSpec`.
    """
    raise NotImplementedError('Calling an abstract method.')
20134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
20234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _check_dense_labels_match_logits_and_reshape(
    labels, logits, expected_labels_dimension):
  """Checks that labels shape matches logits and reshapes if needed.

  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Then labels
  shape must be [D0, D1, ... DN, expected_labels_dimension].
  If expected_labels_dimension=1, labels could be [D0, D1, ... DN] and this
  method reshapes them to [D0, D1, ... DN, 1].

  Validation happens in two stages:
  * If the last labels dimension is known at graph construction time, a
    mismatch raises `ValueError` immediately.
  * The full labels shape is also checked at run time by assertion ops, and
    the returned tensor depends on them.

  Args:
    labels: labels Tensor. Must not be None and must not be a SparseTensor.
    logits: logits Tensor.
    expected_labels_dimension: Integer. Required size of the last dimension
      of labels.

  Returns:
    Validated and reshaped labels Tensor. It is an identity of the (possibly
    expanded) labels, named after the 'labels' name scope, with a control
    dependency on the runtime shape assertion.

  Raises:
    ValueError: If labels is None.
    ValueError: If labels is a SparseTensor.
    ValueError: If labels shape is statically defined and fails validation.
    OpError: If labels shape is not statically defined and fails validation.
  """
  # Fail fast with an actionable message when no labels were provided.
  if labels is None:
    raise ValueError(
        'You must provide a labels Tensor. Given: None. '
        'Suggested troubleshooting steps: Check that your data contain '
        'your label feature. Check that your input_fn properly parses and '
        'returns labels.')
  with ops.name_scope(None, 'labels', (labels, logits)) as scope:
    # Convert to Tensor *or* SparseTensor so that sparse labels can be
    # detected and rejected explicitly with guidance.
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError(
          'SparseTensor labels are not supported. '
          'labels must be a Tensor of shape [D0, D1, ..., DN, %s], '
          'e.g. [batch_size, %s]. '
          'Suggested Fix (1): Check the label feature in your data. '
          'Each example must contain %s value(s). If not, your choice of label '
          'was probably incorrect. '
          'Suggested Fix (2): In your input_fn, use '
          'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'
          '' % (expected_labels_dimension, expected_labels_dimension,
                expected_labels_dimension))
    # When both static ranks are known and labels has exactly one dimension
    # fewer than logits, add a trailing axis of size 1. This happens whatever
    # expected_labels_dimension is; if it is not 1, the shape checks below
    # reject the expanded labels.
    if (labels.shape.ndims is not None and logits.shape.ndims is not None and
        labels.shape.ndims == logits.shape.ndims - 1):
      labels = array_ops.expand_dims(labels, -1)
    # Dynamic shapes, used by the runtime assertion below.
    labels_shape = array_ops.shape(labels)
    logits_shape = array_ops.shape(logits)
    err_msg = (
        'labels shape must be [D0, D1, ... DN, {}]. '
        'Suggested Fix: check your n_classes argument to the estimator '
        'and/or the shape of your label.'.format(expected_labels_dimension))
    # Labels must have rank >= 2, i.e. [D0, ..., DN, dim] with N >= 0.
    assert_rank = check_ops.assert_rank_at_least(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      # Static check of the last dimension, done at graph construction time.
      # NOTE(review): with TF1 TensorShape, `static_shape[-1]` is presumably a
      # `Dimension` object (never None); an unknown Dimension should compare
      # as falsy under `!=`, which skips the check. Confirm for the TF
      # version in use.
      static_shape = labels.shape
      if static_shape.ndims is not None:
        dim1 = static_shape[-1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          # NOTE(review): this message assumes a classifier (n_classes),
          # although this helper only receives expected_labels_dimension.
          # Confirm whether non-classifier heads call it.
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s.  Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      # Runtime check: labels shape must equal logits shape with its last
      # dimension replaced by expected_labels_dimension.
      expected_labels_shape = array_ops.concat(
          [logits_shape[:-1], [expected_labels_dimension]], axis=0)
      assert_dimension = check_ops.assert_equal(
          expected_labels_shape, labels_shape, message=err_msg,
          data=['expected_labels_shape: ', expected_labels_shape,
                'labels_shape: ', labels_shape])
      # Pass labels through an identity op so that consumers of the result
      # depend on the runtime shape assertion.
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
27267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
27367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
274d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlowerdef _get_weights_and_check_match_logits(
275d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    features, weight_column, logits, allow_per_logit_weights=False):
276d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  """Fetches weights from features and checks that the shape matches logits.
27767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
27867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
27967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  can be either:
280d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
2817948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  * [D0, D1, ... DN, 1]
28267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  * [D0, D1, ... DN]: In this case, weights is reshaped into
28367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    [D0, D1, ... DN, 1] to work with weight broadcasting rules.
28467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
28567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  Args:
286d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    features: The features dict that contains weights.
287d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    weight_column: The weight column. If not given, this method returns 1.
28867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    logits: logits Tensor.
289d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    allow_per_logit_weights: Boolean. Whether we allow weights along the logits
290d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.
29167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  Returns:
29267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    Validated and reshaped weights Tensor.
293d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  Raises:
294d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    ValueError: If the weights `Tensor` cannot be cast into float.
29567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  """
296d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  if allow_per_logit_weights:
297d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    err_msg = (
298d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
299d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        '[D0, D1, ... DN, logits_dimension]')
300d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  else:
301d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    err_msg = (
302d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
303d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  with ops.name_scope(
304d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      None, 'weights',
305d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      values=tuple(six.itervalues(features)) + (logits,)) as scope:
306d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    # Fetch the weights.
307d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    if weight_column is None:
308d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      return 1.
309d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    if isinstance(weight_column, six.string_types):
310d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      weight_column = feature_column_lib.numeric_column(
311d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          key=weight_column, shape=(1,))
312d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    if not isinstance(weight_column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
313d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      raise TypeError('Weight column must be either a string or _NumericColumn.'
314d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower                      ' Given type: {}.'.format(type(weight_column)))
315d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access
316d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        feature_column_lib._LazyBuilder(features))  # pylint: disable=protected-access
317d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    if not (weights.dtype.is_floating or weights.dtype.is_integer):
318d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      raise ValueError('Weight column should be castable to float. '
319d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower                       'Given dtype: {}'.format(weights.dtype))
320d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    weights = math_ops.to_float(weights, name='weights')
321d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower
322d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    # Validate the weights shape.
32367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    weights_shape = array_ops.shape(weights, name='weights_shape')
32467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    logits_shape = array_ops.shape(logits, name='logits_shape')
32567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    if (weights.shape.ndims is not None and logits.shape.ndims is not None and
32667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        weights.shape.ndims == logits.shape.ndims - 1):
32767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      assert_dimension = check_ops.assert_equal(
32867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower          logits_shape[:-1], weights_shape, message=err_msg,
32967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower          data=['logits_shape: ', logits_shape,
33067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower                'weights_shape: ', weights_shape])
33167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      with ops.control_dependencies([assert_dimension]):
33267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        return array_ops.expand_dims(weights, -1, name=scope)
33367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
334d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    if allow_per_logit_weights:
335d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      condition = math_ops.reduce_any(
336d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
337d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower           math_ops.reduce_all(math_ops.equal(
338d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower               supported_weights_shape, weights_shape))])
339d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      assert_dimension = control_flow_ops.Assert(
340d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          condition=condition,
341d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          data=[err_msg, 'logits_shape: ', logits_shape,
342d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower                'weights_shape: ', weights_shape])
343d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    else:
344d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower      assert_dimension = check_ops.assert_equal(
345d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          supported_weights_shape, weights_shape, message=err_msg,
346d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower          data=['logits_shape: ', logits_shape,
347d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower                'weights_shape: ', weights_shape])
34867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    with ops.control_dependencies([assert_dimension]):
34967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      return array_ops.identity(weights, name=scope)
35067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
35167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
35267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlowerdef _check_logits_final_dim(logits, expected_logits_dimension):
35367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  """Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
35467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  with ops.name_scope(None, 'logits', (logits,)) as scope:
35567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    logits = math_ops.to_float(logits)
35667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    logits_shape = array_ops.shape(logits)
35767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    assert_rank = check_ops.assert_rank_at_least(
35867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        logits, 2, data=[logits_shape],
35967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        message='logits shape must be [D0, D1, ... DN, logits_dimension]')
36067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower    with ops.control_dependencies([assert_rank]):
36167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      static_shape = logits.shape
36267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      if static_shape.ndims is not None and static_shape[-1] is not None:
36367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        if static_shape[-1] != expected_logits_dimension:
36467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower          raise ValueError(
36567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower              'logits shape must be [D0, D1, ... DN, logits_dimension], '
36667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower              'got %s.' % (static_shape,))
36767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        return logits
36867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      assert_dimension = check_ops.assert_equal(
36967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower          expected_logits_dimension, logits_shape[-1], data=[logits_shape],
37067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower          message='logits shape must be [D0, D1, ... DN, logits_dimension]')
37167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower      with ops.control_dependencies([assert_dimension]):
37267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower        return array_ops.identity(logits, name=scope)
37367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
37467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
3753942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerdef _validate_loss_fn_args(loss_fn):
3763942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  """Validates loss_fn arguments.
3773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
3783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Required arguments: labels, logits.
3793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Optional arguments: features.
3803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
3813942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Args:
3823942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn: The loss function.
3833942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Raises:
3843942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    ValueError: If the signature is unexpected.
3853942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  """
3863942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  loss_fn_args = util.fn_args(loss_fn)
3873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  for required_arg in ['labels', 'logits']:
3883942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    if required_arg not in loss_fn_args:
3893942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      raise ValueError(
3903942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          'loss_fn must contain argument: {}. '
3913942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          'Given arguments: {}'.format(required_arg, loss_fn_args))
3923942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
3933942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  if invalid_args:
3943942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
3953942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
3963942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
3973942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerdef _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
3983942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  """Calls loss_fn and checks the returned shape.
3993942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
4003942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Args:
4013942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn: The loss function.
4023942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    labels: Processed labels Tensor.
4033942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension].
4043942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    features: Features dict.
4053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    expected_loss_dim: The expected last dimension of loss Tensor.
4063942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Returns:
4073942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
4083942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  """
4093942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  loss_fn_args = util.fn_args(loss_fn)
4103942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  kwargs = {}
4113942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  if 'features' in loss_fn_args:
4123942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    kwargs['features'] = features
4133942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  with ops.name_scope(
4143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      None, 'call_loss_fn',
4153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      values=[labels, logits] + list(six.itervalues(features))):
4163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
4173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    logits_shape = array_ops.shape(logits, name='logits_shape')
4183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    expected_loss_shape = array_ops.concat(
4193942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower        [logits_shape[:-1], [expected_loss_dim]], axis=0,
4203942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower        name='expected_loss_shape')
4213942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_shape = array_ops.shape(unweighted_loss, name='loss_shape')
4223942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    check_loss_shape_op = control_flow_ops.Assert(
4233942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower        math_ops.reduce_all(math_ops.equal(loss_shape, expected_loss_shape)),
4243942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower        data=[
4253942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower            'loss_fn must return Tensor of shape '
4263942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower            '[D0, D1, ... DN, {}]. '.format(expected_loss_dim),
4273942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower            'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape],
4283942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower        name='check_loss_shape')
4293942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    with ops.control_dependencies([check_loss_shape_op]):
4303942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      return array_ops.identity(unweighted_loss)
4313942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
4323942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
43334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _indicator_labels_mean(labels, weights=None, name=None):
43434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
43534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    labels = math_ops.to_float(labels, name='labels')
43634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if weights is not None:
43734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
43834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return metrics_lib.mean(labels, weights=weights, name=scope)
43934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
44034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
441df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlowerdef _classification_output(scores, n_classes, label_vocabulary=None):
442df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower  batch_size = array_ops.shape(scores)[0]
443df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower  if label_vocabulary:
444df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower    export_class_list = label_vocabulary
445df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower  else:
446df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower    export_class_list = string_ops.as_string(math_ops.range(n_classes))
447df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower  export_output_classes = array_ops.tile(
448df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower      input=array_ops.expand_dims(input=export_class_list, axis=0),
449df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower      multiples=[batch_size, 1])
450df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower  return export_output.ClassificationOutput(
451df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower      scores=scores,
452df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower      # `ClassificationOutput` requires string classes.
453df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower      classes=export_output_classes)
454df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower
455df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower
45634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _accuracy_baseline(labels_mean):
45734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """Return accuracy baseline based on labels mean.
45834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
45934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  This is the best the model could do by always predicting one class.
46034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
46134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
46234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    labels_mean: Tuple of value and update op.
46334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
46434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
46534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    Tuple of value and update op.
46634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
46734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(None, 'accuracy_baseline', labels_mean):
46834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    value, update_op = labels_mean
46934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return (
47034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        math_ops.maximum(value, 1. - value, name='value'),
47134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        math_ops.maximum(update_op, 1 - update_op, name='update_op'))
47234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
47334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
47434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _predictions_mean(predictions, weights=None, name=None):
47534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(
47634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      name, 'predictions_mean', (predictions, weights)) as scope:
47734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    predictions = math_ops.to_float(predictions, name='predictions')
47834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if weights is not None:
47934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
48034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return metrics_lib.mean(predictions, weights=weights, name=scope)
48134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
48234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
48334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _auc(labels, predictions, weights=None, curve='ROC', name=None):
48434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(name, 'auc', (predictions, labels, weights)) as scope:
48534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    predictions = math_ops.to_float(predictions, name='predictions')
48634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if weights is not None:
48734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
48834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return metrics_lib.auc(
48934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        labels=labels, predictions=predictions, weights=weights, curve=curve,
49034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        name=scope)
49134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
49234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
49334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _accuracy_at_threshold(labels, predictions, weights, threshold, name=None):
49434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(
49534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      name, 'accuracy_at_%s' % threshold,
49634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      (predictions, labels, weights, threshold)) as scope:
49734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    threshold_predictions = math_ops.to_float(
49834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        math_ops.greater_equal(predictions, threshold))
49934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return metrics_lib.accuracy(
50034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        labels=labels, predictions=threshold_predictions, weights=weights,
50134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        name=scope)
50234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
50334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
50434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _precision_at_threshold(labels, predictions, weights, threshold, name=None):
50534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(
50634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      name, 'precision_at_%s' % threshold,
50734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      (predictions, labels, weights, threshold)) as scope:
50834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    precision_tensor, update_op = metrics_lib.precision_at_thresholds(
50934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        labels=labels, predictions=predictions, thresholds=(threshold,),
51034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        weights=weights, name=scope)
51134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
51234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
51334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
51434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _recall_at_threshold(labels, predictions, weights, threshold, name=None):
51534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  with ops.name_scope(
51634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      name, 'recall_at_%s' % threshold,
51734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      (predictions, labels, weights, threshold)) as scope:
51834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    precision_tensor, update_op = metrics_lib.recall_at_thresholds(
51934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        labels=labels, predictions=predictions, thresholds=(threshold,),
52034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        weights=weights, name=scope)
52134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
52234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
52334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
52420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlowerdef _multi_class_head_with_softmax_cross_entropy_loss(
52520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    n_classes,
52620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    weight_column=None,
52720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    label_vocabulary=None,
52820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    loss_reduction=losses.Reduction.SUM,
5293942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn=None,
53020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    name=None):
53134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """Creates a '_Head' for multi class classification.
53234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
5337948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
5347948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  In many applications, the shape is `[batch_size, n_classes]`.
5357948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower
5367948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `labels` must be a dense `Tensor` with shape matching `logits`, namely
5377948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
5387948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
5397948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `labels` must be an integer `Tensor` with values specifying the class index.
5407948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower
5417948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  If `weight_column` is specified, weights must be of shape
5427948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
5437948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower
5447948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  The loss is the weighted sum over the input dimensions. Namely, if the input
5457948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
5467948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower  `batch_size`.
54734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
5483942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
5493942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  `(labels, logits, features)` as arguments and returns unreduced loss with
5503942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
5513942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
5523942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  the input labels before passing them to `loss_fn`.
5533942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
55434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
55534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
55634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
557d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    weight_column: A string or a `_NumericColumn` created by
558d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      `tf.feature_column.numeric_column` defining feature column representing
55934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights. It is used to down weight or boost examples during training. It
56034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      will be multiplied by the loss of the example.
5619c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol    label_vocabulary: A list or tuple of strings representing possible label
5629c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      values. If it is not given, that means labels are already encoded as an
5639c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      integer within [0, n_classes). If given, labels must be of string type and
5649c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      have any value in `label_vocabulary`. Note that errors will be raised if
5659c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      `label_vocabulary` is not provided but labels are strings.
56620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
56720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      reduce training loss over batch. Defaults to `SUM`.
5683942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn: Optional loss function.
569abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    name: name of the head. If provided, summary and metrics keys will be
57001c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
57134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
57234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
5730815de21239955e346b562e899640649c8d2b9cbBenoit Steiner    An instance of `_Head` for multi class classification.
57434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
57534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Raises:
57620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    ValueError: If `n_classes`, `label_vocabulary` or `loss_reduction` is
57720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      invalid.
57834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
579edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir  if label_vocabulary is not None and not isinstance(label_vocabulary,
580edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                                                     (list, tuple)):
5819c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol    raise ValueError(
5829c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol        'label_vocabulary should be a list or a tuple. Given type: {}'.format(
5839c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol            type(label_vocabulary)))
58420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower  if (loss_reduction not in losses.Reduction.all() or
58520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction == losses.Reduction.NONE):
58620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
5873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  if loss_fn:
5883942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    _validate_loss_fn_args(loss_fn)
58920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower  return _MultiClassHeadWithSoftmaxCrossEntropyLoss(
59020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      n_classes=n_classes,
59120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      weight_column=weight_column,
59220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      label_vocabulary=label_vocabulary,
59320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction=loss_reduction,
5943942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      loss_fn=loss_fn,
59520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      name=name)
59634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
59734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
59834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
59934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """See `_multi_class_head_with_softmax_cross_entropy_loss`."""
60034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
6010fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower  def __init__(self,
6020fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               n_classes,
6030fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               weight_column=None,
6040fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               label_vocabulary=None,
60520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower               loss_reduction=losses.Reduction.SUM,
6063942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower               loss_fn=None,
607abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower               name=None):
60834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if (n_classes is None) or (n_classes <= 2):
60934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      raise ValueError('n_classes must be > 2: %s.' % n_classes)
61034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    self._n_classes = n_classes
611d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    self._weight_column = weight_column
612ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    self._label_vocabulary = label_vocabulary
61320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    self._loss_reduction = loss_reduction
6143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    self._loss_fn = loss_fn
615abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    self._name = name
616abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower
617abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  @property
618abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  def name(self):
619abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    return self._name
62034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
62134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  @property
62234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  def logits_dimension(self):
62334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return self._n_classes
62434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
625c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower  def _eval_metric_ops(
626c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      self, labels, class_ids, weights, unreduced_loss, regularization_loss):
62734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    """Returns the Eval metric ops."""
62834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    with ops.name_scope(
629c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        None, 'metrics',
630c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        (labels, class_ids, weights, unreduced_loss, regularization_loss)):
63134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      keys = metric_keys.MetricKeys
63234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      metric_ops = {
63334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          # Estimator already adds a metric for loss.
63434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          # TODO(xiejw): Any other metrics?
635abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.LOSS_MEAN):
6360fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower              metrics_lib.mean(
63720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower                  values=unreduced_loss,
63820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower                  weights=weights,
63999d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower                  name=keys.LOSS_MEAN),
640abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.ACCURACY):
6410fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower              metrics_lib.accuracy(
6420fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                  labels=labels,
6430fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                  predictions=class_ids,
6440fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                  weights=weights,
6450fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                  name=keys.ACCURACY),
64634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      }
647c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      if regularization_loss is not None:
648c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (
649c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            metrics_lib.mean(
650c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower                values=regularization_loss,
651c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower                name=keys.LOSS_REGULARIZATION))
65234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return metric_ops
65334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
654ce21f5e372360451528c9716a743b65308421423Mustafa Ispir  def _label_ids(self, labels):
655ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    """Converts labels to integer id space."""
656ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    if self._label_vocabulary is None:
657ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      if not labels.dtype.is_integer:
6589c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol        raise ValueError('Labels dtype should be integer. Instead got {}.'.
6599c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol                         format(labels.dtype))
660ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      label_ids = labels
661ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    else:
662ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      if labels.dtype != dtypes.string:
663ce21f5e372360451528c9716a743b65308421423Mustafa Ispir        raise ValueError('Labels dtype should be string if there is a '
664ce21f5e372360451528c9716a743b65308421423Mustafa Ispir                         'vocabulary. Instead got {}'.format(labels.dtype))
665ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      label_ids = lookup_ops.index_table_from_tensor(
666ce21f5e372360451528c9716a743b65308421423Mustafa Ispir          vocabulary_list=tuple(self._label_vocabulary),
667ce21f5e372360451528c9716a743b65308421423Mustafa Ispir          name='class_id_lookup').lookup(labels)
668edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    return _assert_range(label_ids, self._n_classes)
669ce21f5e372360451528c9716a743b65308421423Mustafa Ispir
6708d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower  def create_loss(self, features, mode, logits, labels):
6718d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    """See `Head`."""
67299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower    del mode  # Unused for this head.
6737948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    logits = ops.convert_to_tensor(logits)
6747948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    labels = _check_dense_labels_match_logits_and_reshape(
6757948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        labels=labels, logits=logits, expected_labels_dimension=1)
6767948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    label_ids = self._label_ids(labels)
6773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    if self._loss_fn:
6783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      unweighted_loss = _call_loss_fn(
6793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          loss_fn=self._loss_fn, labels=label_ids, logits=logits,
6803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          features=features, expected_loss_dim=1)
6813942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    else:
6823942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      unweighted_loss = losses.sparse_softmax_cross_entropy(
6833942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
6843942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      # Restore the squeezed dim, so unweighted_loss matches the weights shape.
6853942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
686d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    weights = _get_weights_and_check_match_logits(
687d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        features=features, weight_column=self._weight_column, logits=logits)
68820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    training_loss = losses.compute_weighted_loss(
68920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        unweighted_loss, weights=weights, reduction=self._loss_reduction)
69099d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower    return LossSpec(
69120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        training_loss=training_loss,
69220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        unreduced_loss=unweighted_loss,
69320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        weights=weights,
6948d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower        processed_labels=label_ids)
6958d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
69634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  def create_estimator_spec(
697c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      self, features, mode, logits, labels=None, train_op_fn=None,
698c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      regularization_losses=None):
6997948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    """Returns an `EstimatorSpec`.
7007948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower
7017948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    Args:
7027948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      features: Input `dict` of `Tensor` or `SparseTensor` objects.
7037948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      mode: Estimator's `ModeKeys`.
7047948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
7057948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        For many applications, the shape is `[batch_size, logits_dimension]`.
7067948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      labels: Labels integer or string `Tensor` with shape matching `logits`,
7073f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is
7083f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        required argument when `mode` equals `TRAIN` or `EVAL`.
7097948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      train_op_fn: Function that takes a scalar loss `Tensor` and returns
7107948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        `train_op`. Required in TRAIN mode.
711c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      regularization_losses: A list of additional scalar losses to be added to
712c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        the training loss, such as regularization losses. These losses are
713c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        usually expressed as a batch average, so for best results users need to
7143f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
7153f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
716c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        avoid scaling errors.
7177948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    Returns:
7187948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      `EstimatorSpec`.
7197948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    Raises:
7207948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      ValueError: If `train_op_fn` is `None` in TRAIN mode.
7217948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower    """
72201c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower    with ops.name_scope(self._name, 'head'):
7237948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower      logits = _check_logits_final_dim(logits, self.logits_dimension)
72434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
72534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      # Predict.
72634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      pred_keys = prediction_keys.PredictionKeys
72734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      with ops.name_scope(None, 'predictions', (logits,)):
7287948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        # class_ids's shape is [D0, D1, ... DN].
7297948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        class_ids = math_ops.argmax(logits, axis=-1, name=pred_keys.CLASS_IDS)
7307948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower        class_ids = array_ops.expand_dims(class_ids, axis=-1)
731ce21f5e372360451528c9716a743b65308421423Mustafa Ispir        if self._label_vocabulary:
732ce21f5e372360451528c9716a743b65308421423Mustafa Ispir          table = lookup_ops.index_to_string_table_from_tensor(
733ce21f5e372360451528c9716a743b65308421423Mustafa Ispir              vocabulary_list=self._label_vocabulary,
734ce21f5e372360451528c9716a743b65308421423Mustafa Ispir              name='class_string_lookup')
735ce21f5e372360451528c9716a743b65308421423Mustafa Ispir          classes = table.lookup(class_ids)
736ce21f5e372360451528c9716a743b65308421423Mustafa Ispir        else:
737ce21f5e372360451528c9716a743b65308421423Mustafa Ispir          classes = string_ops.as_string(class_ids, name='str_classes')
738ce21f5e372360451528c9716a743b65308421423Mustafa Ispir
73934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        probabilities = nn.softmax(logits, name=pred_keys.PROBABILITIES)
74034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        predictions = {
74134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            pred_keys.LOGITS: logits,
74234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            pred_keys.PROBABILITIES: probabilities,
74334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            # Expand to [batch_size, 1]
744ce21f5e372360451528c9716a743b65308421423Mustafa Ispir            pred_keys.CLASS_IDS: class_ids,
745ce21f5e372360451528c9716a743b65308421423Mustafa Ispir            pred_keys.CLASSES: classes,
74634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        }
74734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if mode == model_fn.ModeKeys.PREDICT:
748df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower        classifier_output = _classification_output(
749df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            scores=probabilities, n_classes=self._n_classes,
750df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            label_vocabulary=self._label_vocabulary)
75134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        return model_fn.EstimatorSpec(
75234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            mode=model_fn.ModeKeys.PREDICT,
75334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            predictions=predictions,
754ce21f5e372360451528c9716a743b65308421423Mustafa Ispir            export_outputs={
7554db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _DEFAULT_SERVING_KEY: classifier_output,
7564db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _CLASSIFY_SERVING_KEY: classifier_output,
7574db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
758ce21f5e372360451528c9716a743b65308421423Mustafa Ispir            })
75934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
76020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      training_loss, unreduced_loss, weights, label_ids = self.create_loss(
7618d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower          features=features, mode=mode, logits=logits, labels=labels)
762c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      if regularization_losses:
763c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        regularization_loss = math_ops.add_n(regularization_losses)
764c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        regularized_training_loss = math_ops.add_n(
765c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            [training_loss, regularization_loss])
766c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      else:
767c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        regularization_loss = None
768c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        regularized_training_loss = training_loss
76999d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower      # Eval.
77034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if mode == model_fn.ModeKeys.EVAL:
77134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        return model_fn.EstimatorSpec(
77234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            mode=model_fn.ModeKeys.EVAL,
77334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            predictions=predictions,
774c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            loss=regularized_training_loss,
77534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            eval_metric_ops=self._eval_metric_ops(
776ce21f5e372360451528c9716a743b65308421423Mustafa Ispir                labels=label_ids,
77734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower                class_ids=class_ids,
77820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower                weights=weights,
779c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower                unreduced_loss=unreduced_loss,
780c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower                regularization_loss=regularization_loss))
78134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
78234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      # Train.
78334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if train_op_fn is None:
7849c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol        raise ValueError('train_op_fn cannot be None.')
78520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      # Only summarize mean_loss for SUM reduction to preserve backwards
78620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      # compatibility. Otherwise skip it to avoid unnecessary computation.
78720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      if self._loss_reduction == losses.Reduction.SUM:
78820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        example_weight_sum = math_ops.reduce_sum(
78920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower            weights * array_ops.ones_like(unreduced_loss))
79020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        mean_loss = training_loss / example_weight_sum
79120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      else:
79220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        mean_loss = None
793169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir    with ops.name_scope(''):
794c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      keys = metric_keys.MetricKeys
7950fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      summary.scalar(
796c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower          _summary_key(self._name, keys.LOSS),
797c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower          regularized_training_loss)
79820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      if mean_loss is not None:
79920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        summary.scalar(
800c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            _summary_key(self._name, keys.LOSS_MEAN),
80120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower            mean_loss)
802c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower      if regularization_loss is not None:
803c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        summary.scalar(
804c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            _summary_key(self._name, keys.LOSS_REGULARIZATION),
805c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower            regularization_loss)
806169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir    return model_fn.EstimatorSpec(
807169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        mode=model_fn.ModeKeys.TRAIN,
808169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        predictions=predictions,
809c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        loss=regularized_training_loss,
810c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower        train_op=train_op_fn(regularized_training_loss))
81134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
81234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
81334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _binary_logistic_head_with_sigmoid_cross_entropy_loss(
8143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    weight_column=None,
8153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    thresholds=None,
8163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    label_vocabulary=None,
8173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_reduction=losses.Reduction.SUM,
8183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn=None,
8193942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    name=None):
820d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  """Creates a `_Head` for single label binary classification.
82134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
82234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  This head uses `sigmoid_cross_entropy_with_logits` loss.
82334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
824b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
825b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  In many applications, the shape is `[batch_size, 1]`.
826b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower
827b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `labels` must be a dense `Tensor` with shape matching `logits`, namely
828b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
829b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
830b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `labels` must be float `Tensor` with values in the interval `[0, 1]`.
831b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower
832b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  If `weight_column` is specified, weights must be of shape
833b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
834b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower
835b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  The loss is the weighted sum over the input dimensions. Namely, if the input
836b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
837b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower  `batch_size`.
83834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
8393942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
8403942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  `(labels, logits, features)` as arguments and returns unreduced loss with
8413942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
8423942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
8433942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  the input labels before passing them to `loss_fn`.
8443942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
84534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
846d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    weight_column: A string or a `_NumericColumn` created by
847d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      `tf.feature_column.numeric_column` defining feature column representing
84834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights. It is used to down weight or boost examples during training. It
84934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      will be multiplied by the loss of the example.
85034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    thresholds: Iterable of floats in the range `(0, 1)`. For binary
85134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      classification metrics such as precision and recall, an eval metric is
85234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      generated for each threshold value. This threshold is applied to the
85334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      logistic values to determine the binary classification (i.e., above the
85434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      threshold is `true`, below is `false`.
8559c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol    label_vocabulary: A list or tuple of strings representing possible label
8569c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      values. If it is not given, that means labels are already encoded within
8579c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      [0, 1]. If given, labels must be string type and have any value in
8589c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
8599c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      is not provided but labels are strings.
86020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
86120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      reduce training loss over batch. Defaults to `SUM`.
8623942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn: Optional loss function.
863abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    name: name of the head. If provided, summary and metrics keys will be
86401c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
86534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
86634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
867d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    An instance of `_Head` for binary classification.
86834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
86934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Raises:
87020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    ValueError: If `thresholds` contains a value outside of `(0, 1)`.
87120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    ValueError: If `loss_reduction` is invalid.
87234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
87379099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower  thresholds = tuple(thresholds) if thresholds else tuple()
874edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir  if label_vocabulary is not None and not isinstance(label_vocabulary,
875edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                                                     (list, tuple)):
8769c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol    raise ValueError(
8779c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol        'label_vocabulary should be a list or tuple. Given type: {}'.format(
8789c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol            type(label_vocabulary)))
879edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
88034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  for threshold in thresholds:
88134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if (threshold <= 0.0) or (threshold >= 1.0):
8829c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol      raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,)))
88320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower  if (loss_reduction not in losses.Reduction.all() or
88420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction == losses.Reduction.NONE):
88520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
8863942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  if loss_fn:
8873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    _validate_loss_fn_args(loss_fn)
88834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
889d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      weight_column=weight_column,
890edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      thresholds=thresholds,
8910fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      label_vocabulary=label_vocabulary,
89220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction=loss_reduction,
8933942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      loss_fn=loss_fn,
894abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower      name=name)
89534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
89634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
89734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
89834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
89934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
9000fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower  def __init__(self,
9010fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               weight_column=None,
9020fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               thresholds=None,
9030fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               label_vocabulary=None,
90420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower               loss_reduction=losses.Reduction.SUM,
9053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower               loss_fn=None,
906abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower               name=None):
907d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    self._weight_column = weight_column
90879099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower    self._thresholds = thresholds
909edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    self._label_vocabulary = label_vocabulary
91020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    self._loss_reduction = loss_reduction
9113942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    self._loss_fn = loss_fn
912abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    self._name = name
913abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower
914abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  @property
915abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  def name(self):
916abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    return self._name
91734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
91834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  @property
91934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  def logits_dimension(self):
92034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return 1
92134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
92299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower  def _eval_metric_ops(self, labels, logits, logistic, class_ids, weights,
9233f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                       unreduced_loss, regularization_loss):
92499d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower    with ops.name_scope(None, 'metrics',
92599d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower                        (labels, logits, logistic, class_ids, weights,
9263f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                         unreduced_loss, regularization_loss)):
92734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      keys = metric_keys.MetricKeys
92834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      labels_mean = _indicator_labels_mean(
92934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          labels=labels, weights=weights, name=keys.LABEL_MEAN)
93034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      metric_ops = {
93134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          # Estimator already adds a metric for loss.
932abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.LOSS_MEAN):
933edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              metrics_lib.mean(
93420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower                  values=unreduced_loss,
93520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower                  weights=weights,
93699d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower                  name=keys.LOSS_MEAN),
937abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.ACCURACY):
938edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              metrics_lib.accuracy(
939edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
940edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=class_ids,
941edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
942edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.ACCURACY),
943abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.PREDICTION_MEAN):
944edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _predictions_mean(
945edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
946edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
947edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.PREDICTION_MEAN),
948abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.LABEL_MEAN):
949edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              labels_mean,
950abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.ACCURACY_BASELINE):
951edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _accuracy_baseline(labels_mean),
952abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.AUC):
953edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _auc(
954edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
955edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
956edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
957edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.AUC),
958abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.AUC_PR):
959edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _auc(
960edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
961edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
962edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
963edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  curve='PR',
964edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.AUC_PR)
96534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      }
9663f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      if regularization_loss is not None:
9673f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (
9683f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            metrics_lib.mean(
9693f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                values=regularization_loss,
9703f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                name=keys.LOSS_REGULARIZATION))
97134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      for threshold in self._thresholds:
97234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
973abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
9740fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                accuracy_key)] = _accuracy_at_threshold(
9750fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
9760fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
9770fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
9780fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
9790fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=accuracy_key)
98034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        # Precision for positive examples.
98134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        precision_key = keys.PRECISION_AT_THRESHOLD % threshold
982abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
9830fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                precision_key)] = _precision_at_threshold(
9840fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
9850fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
9860fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
9870fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
9880fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=precision_key)
98934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        # Recall for positive examples.
99034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        recall_key = keys.RECALL_AT_THRESHOLD % threshold
991abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
9920fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                recall_key)] = _recall_at_threshold(
9930fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
9940fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
9950fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
9960fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
9970fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=recall_key)
99834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      return metric_ops
99934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
10008d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower  def create_loss(self, features, mode, logits, labels):
10018d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    """See `Head`."""
100299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower    del mode  # Unused for this head.
1003b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower    logits = ops.convert_to_tensor(logits)
1004b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower    labels = _check_dense_labels_match_logits_and_reshape(
1005b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower        labels=labels, logits=logits, expected_labels_dimension=1)
10068d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    if self._label_vocabulary is not None:
10078d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower      labels = lookup_ops.index_table_from_tensor(
10088d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower          vocabulary_list=tuple(self._label_vocabulary),
10098d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower          name='class_id_lookup').lookup(labels)
10108d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    labels = math_ops.to_float(labels)
10118d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    labels = _assert_range(labels, 2)
10123942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    if self._loss_fn:
10133942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      unweighted_loss = _call_loss_fn(
10143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          loss_fn=self._loss_fn, labels=labels, logits=logits,
10153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          features=features, expected_loss_dim=1)
10163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    else:
10173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
10183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower          labels=labels, logits=logits)
1019d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower    weights = _get_weights_and_check_match_logits(
1020d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower        features=features, weight_column=self._weight_column, logits=logits)
102120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    training_loss = losses.compute_weighted_loss(
102220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        unweighted_loss, weights=weights, reduction=self._loss_reduction)
102399d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower    return LossSpec(
102420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        training_loss=training_loss,
102520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        unreduced_loss=unweighted_loss,
102620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        weights=weights,
10278d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower        processed_labels=labels)
10288d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
102934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  def create_estimator_spec(
10303f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      self, features, mode, logits, labels=None, train_op_fn=None,
10313f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      regularization_losses=None):
10323f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower    """Returns an `EstimatorSpec`.
10333f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower
10343f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower    Args:
10353f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      features: Input `dict` of `Tensor` or `SparseTensor` objects.
10363f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      mode: Estimator's `ModeKeys`.
10373f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
10383f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        applications, the shape is `[batch_size, 1]`.
10393f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      labels: Labels integer or string `Tensor` with shape matching `logits`,
10403f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required
10413f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        argument when `mode` equals `TRAIN` or `EVAL`.
10423f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      train_op_fn: Function that takes a scalar loss `Tensor` and returns
10433f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        `train_op`. Required in TRAIN mode.
10443f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      regularization_losses: A list of additional scalar losses to be added to
10453f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        the training loss, such as regularization losses. These losses are
10463f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        usually expressed as a batch average, so for best results users need to
10473f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
10483f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
10493f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        avoid scaling errors.
10503f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower    Returns:
10513f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      `EstimatorSpec`.
10523f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower    Raises:
10533f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      ValueError: If `train_op_fn` is `None` in TRAIN mode.
10543f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower    """
1055169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir    # Predict.
105601c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower    with ops.name_scope(self._name, 'head'):
1057169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir      with ops.name_scope(None, 'predictions', (logits,)):
1058169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        pred_keys = prediction_keys.PredictionKeys
1059b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower        logits = _check_logits_final_dim(logits, self.logits_dimension)
1060169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC)
1061169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        two_class_logits = array_ops.concat(
1062b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower            (array_ops.zeros_like(logits), logits),
1063b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower            axis=-1, name='two_class_logits')
1064df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower        probabilities = nn.softmax(
1065df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            two_class_logits, name=pred_keys.PROBABILITIES)
1066b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower        class_ids = math_ops.argmax(
1067b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower            two_class_logits, axis=-1, name=pred_keys.CLASS_IDS)
1068b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower        class_ids = array_ops.expand_dims(class_ids, axis=-1)
1069169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        if self._label_vocabulary:
1070169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir          table = lookup_ops.index_to_string_table_from_tensor(
1071169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir              vocabulary_list=self._label_vocabulary,
1072169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir              name='class_string_lookup')
1073169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir          classes = table.lookup(class_ids)
1074169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        else:
1075169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir          classes = string_ops.as_string(class_ids, name='str_classes')
1076169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        predictions = {
1077169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir            pred_keys.LOGITS: logits,
1078169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir            pred_keys.LOGISTIC: logistic,
1079df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            pred_keys.PROBABILITIES: probabilities,
1080169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir            pred_keys.CLASS_IDS: class_ids,
1081169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir            pred_keys.CLASSES: classes,
1082169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        }
108334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if mode == model_fn.ModeKeys.PREDICT:
1084df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower        classifier_output = _classification_output(
1085df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            scores=probabilities, n_classes=2,
1086df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower            label_vocabulary=self._label_vocabulary)
108734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        return model_fn.EstimatorSpec(
108834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            mode=model_fn.ModeKeys.PREDICT,
108934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            predictions=predictions,
1090edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir            export_outputs={
10914db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _DEFAULT_SERVING_KEY: classifier_output,
10924db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _CLASSIFY_SERVING_KEY: classifier_output,
10934db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _REGRESS_SERVING_KEY: export_output.RegressionOutput(
10944db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                    value=logistic),
10954db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
1096edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir            })
109734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
109820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      (training_loss, unreduced_loss, weights, processed_labels) = (
109920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower          self.create_loss(
110020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower              features=features, mode=mode, logits=logits, labels=labels))
11013f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      if regularization_losses:
11023f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        regularization_loss = math_ops.add_n(regularization_losses)
11033f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        regularized_training_loss = math_ops.add_n(
11043f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            [training_loss, regularization_loss])
11053f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      else:
11063f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        regularization_loss = None
11073f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        regularized_training_loss = training_loss
110899d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower
110934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      # Eval.
111034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if mode == model_fn.ModeKeys.EVAL:
111134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        return model_fn.EstimatorSpec(
111234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            mode=model_fn.ModeKeys.EVAL,
111334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            predictions=predictions,
11143f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            loss=regularized_training_loss,
111534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower            eval_metric_ops=self._eval_metric_ops(
11168d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower                labels=processed_labels,
111734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower                logits=logits,
111834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower                logistic=logistic,
1119edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                class_ids=class_ids,
1120b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower                weights=weights,
11213f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                unreduced_loss=unreduced_loss,
11223f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower                regularization_loss=regularization_loss))
112334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
112434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      # Train.
112534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      if train_op_fn is None:
112634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        raise ValueError('train_op_fn can not be None.')
112720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      # Only summarize mean_loss for SUM reduction to preserve backwards
112820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      # compatibility. Otherwise skip it to avoid unnecessary computation.
112920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      if self._loss_reduction == losses.Reduction.SUM:
113020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        example_weight_sum = math_ops.reduce_sum(
113120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower            weights * array_ops.ones_like(unreduced_loss))
113220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        mean_loss = training_loss / example_weight_sum
113320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      else:
113420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        mean_loss = None
1135169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir    with ops.name_scope(''):
11363f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      keys = metric_keys.MetricKeys
11370fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      summary.scalar(
11383f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower          _summary_key(self._name, keys.LOSS),
11393f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower          regularized_training_loss)
114020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      if mean_loss is not None:
114120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower        summary.scalar(
11423f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
11433f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower      if regularization_loss is not None:
11443f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        summary.scalar(
11453f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            _summary_key(self._name, keys.LOSS_REGULARIZATION),
11463f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower            regularization_loss)
1147169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir    return model_fn.EstimatorSpec(
1148169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        mode=model_fn.ModeKeys.TRAIN,
1149169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir        predictions=predictions,
11503f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        loss=regularized_training_loss,
11513f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower        train_op=train_op_fn(regularized_training_loss))
115234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
115334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
115420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlowerdef _regression_head_with_mean_squared_error_loss(
115520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    weight_column=None,
115620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    label_dimension=1,
115720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    loss_reduction=losses.Reduction.SUM,
11583942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn=None,
115920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    name=None):
1160d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower  """Creates a `_Head` for regression using the `mean_squared_error` loss.
116134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
116267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  The loss is the weighted sum over all input dimensions. Namely, if the input
116367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  labels have shape `[batch_size, label_dimension]`, the loss is the weighted
116467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  sum over both `batch_size` and `label_dimension`.
116567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
116667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
116767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  In many applications, the shape is `[batch_size, label_dimension]`.
116867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
116967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  The `labels` shape must match `logits`, namely
117067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
117167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  `[D0, D1, ... DN]` is also supported.
117267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
117367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  If `weight_column` is specified, weights must be of shape
117467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
117567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower  `[D0, D1, ... DN, label_dimension]`.
117667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower
11773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
11783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  `(labels, logits, features)` as arguments and returns unreduced loss with
11793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  shape `[D0, D1, ... DN, label_dimension]`.
11803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower
118134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
1182d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    weight_column: A string or a `_NumericColumn` created by
1183d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      `tf.feature_column.numeric_column` defining feature column representing
118434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights. It is used to down weight or boost examples during training. It
118534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      will be multiplied by the loss of the example.
118634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    label_dimension: Number of regression labels per example. This is the size
118734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      of the last dimension of the labels `Tensor` (typically, this has shape
118834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      `[batch_size, label_dimension]`).
118920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
119020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      reduce training loss over batch. Defaults to `SUM`.
11913942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    loss_fn: Optional loss function.
1192abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    name: name of the head. If provided, summary and metrics keys will be
119301c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
119434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
119534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
119634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    An instance of `_Head` for linear regression.
119720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower
119820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower  Raises:
119920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    ValueError: If `label_dimension` or `loss_reduction` is invalid.
120034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
120120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower  if (loss_reduction not in losses.Reduction.all() or
120220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction == losses.Reduction.NONE):
120320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
12043942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower  if loss_fn:
12053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower    _validate_loss_fn_args(loss_fn)
120634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  return _RegressionHeadWithMeanSquaredErrorLoss(
12070fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      weight_column=weight_column,
12080fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      label_dimension=label_dimension,
120920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower      loss_reduction=loss_reduction,
12103942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower      loss_fn=loss_fn,
1211abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower      name=name)
121234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
121334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
  """`Head` for regression using the mean squared loss.

  Predictions are the (dimension-checked) logits themselves. The per-element
  loss is the squared error between labels and logits
  (`losses.mean_squared_error` with `Reduction.NONE`) unless a custom
  `loss_fn` is supplied. It is then weighted and reduced according to
  `loss_reduction`.
  """

  def __init__(
      self,
      label_dimension,
      weight_column=None,
      loss_reduction=losses.Reduction.SUM,
      loss_fn=None,
      name=None):
    """`Head` for regression.

    Args:
      label_dimension: Number of regression targets per example. It is also
        used as the logits dimension. Must be >= 1.
      weight_column: Optional weight column specification, forwarded to
        `_get_weights_and_check_match_logits`. NOTE(review): presumably a
        feature key string or a numeric column; confirm against that helper.
      loss_reduction: A `losses.Reduction` value used to reduce the weighted
        loss. It is not validated here; the `regression_head` factory rejects
        unknown values and `NONE` before constructing this head.
      loss_fn: Optional callable that replaces the default squared-error loss.
        It is invoked through `_call_loss_fn` with `labels`, `logits` and
        `features`.
      name: Optional head name. Used as the `name_scope` for the head's ops
        and passed to `_summary_key` when building summary and metric keys.

    Raises:
      ValueError: If `label_dimension` is less than 1.
    """
    if label_dimension < 1:
      raise ValueError('Invalid label_dimension %s.' % label_dimension)
    self._logits_dimension = label_dimension
    self._weight_column = weight_column
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._name = name

  @property
  def name(self):
    """The head name given at construction (may be `None`)."""
    return self._name

  @property
  def logits_dimension(self):
    """Size of the last logits dimension; equals `label_dimension`."""
    return self._logits_dimension

  def create_loss(self, features, mode, logits, labels):
    """See `Head`.

    Labels are validated against `logits` (and reshaped if needed) by
    `_check_dense_labels_match_logits_and_reshape`, then cast to float.

    Args:
      features: Input features. Used to look up weights and passed to a
        custom `loss_fn`.
      mode: Ignored by this head.
      logits: Logits `Tensor`, or a value convertible to one.
      labels: Labels compatible with `logits`.

    Returns:
      A `LossSpec` holding the reduced `training_loss`, the per-element
      `unreduced_loss`, the `weights` that were applied, and the float
      `processed_labels`.
    """
    del mode  # Unused for this head.
    logits = ops.convert_to_tensor(logits)
    labels = _check_dense_labels_match_logits_and_reshape(
        labels=labels, logits=logits,
        expected_labels_dimension=self._logits_dimension)
    labels = math_ops.to_float(labels)
    if self._loss_fn:
      unweighted_loss = _call_loss_fn(
          loss_fn=self._loss_fn, labels=labels, logits=logits,
          features=features, expected_loss_dim=self._logits_dimension)
    else:
      # Reduction.NONE keeps one loss value per element; weighting and
      # reduction happen below.
      unweighted_loss = losses.mean_squared_error(
          labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
    # Weights may be per example or per logit (allow_per_logit_weights=True).
    weights = _get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits,
        allow_per_logit_weights=True)
    training_loss = losses.compute_weighted_loss(
        unweighted_loss, weights=weights, reduction=self._loss_reduction)
    return LossSpec(
        training_loss=training_loss,
        unreduced_loss=unweighted_loss,
        weights=weights,
        processed_labels=labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None,
      regularization_losses=None):
    """Returns an `EstimatorSpec`.

    In PREDICT mode the spec carries only predictions and export outputs. In
    EVAL mode it carries the (regularized) loss and eval metrics. Any other
    mode is handled as TRAIN: loss summaries are written and `train_op_fn` is
    called on the regularized loss.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
        For many applications, the shape is `[batch_size, logits_dimension]`.
      labels: Labels `Tensor` with shape matching `logits`, namely
        `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape
        `[D0, D1, ... DN]` is also supported. `labels` is required argument when
        `mode` equals `TRAIN` or `EVAL`.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns
        `train_op`. Required in TRAIN mode.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses. These losses are
        usually expressed as a batch average, so for best results users need to
        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
        `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
        avoid scaling errors.

    Returns:
      `EstimatorSpec`.

    Raises:
      ValueError: If `train_op_fn` is `None` in TRAIN mode (i.e. in any mode
        other than PREDICT or EVAL).
    """
    # Predict.
    # Prediction, loss and metric ops are created under a scope named after
    # the head (falling back to 'head' when no name was given).
    with ops.name_scope(self._name, 'head'):
      logits = _check_logits_final_dim(logits, self._logits_dimension)
      # Regression predictions are the logits themselves.
      predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits}
      if mode == model_fn.ModeKeys.PREDICT:
        # The same RegressionOutput backs both the default and the regress
        # serving signatures.
        regression_output = export_output.RegressionOutput(value=logits)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: regression_output,
                _REGRESS_SERVING_KEY: regression_output,
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      training_loss, unreduced_loss, weights, _ = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      # Regularization losses are added unscaled to the already-reduced
      # training loss (see the `regularization_losses` note above).
      if regularization_losses:
        regularization_loss = math_ops.add_n(regularization_losses)
        regularized_training_loss = math_ops.add_n(
            [training_loss, regularization_loss])
      else:
        regularization_loss = None
        regularized_training_loss = training_loss

      # Eval.
      if mode == model_fn.ModeKeys.EVAL:
        keys = metric_keys.MetricKeys
        # Estimator already adds a metric for loss.
        # The LOSS_MEAN metric is the weighted mean of the per-element loss.
        eval_metric_ops = {
            _summary_key(self._name, keys.LOSS_MEAN):
                metrics_lib.mean(
                    values=unreduced_loss,
                    weights=weights)
        }
        if regularization_loss is not None:
          regularization_loss_key = _summary_key(
              self._name, keys.LOSS_REGULARIZATION)
          eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
              values=regularization_loss,
              name=keys.LOSS_REGULARIZATION)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=regularized_training_loss,
            eval_metric_ops=eval_metric_ops)

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
      # Only summarize mean_loss for SUM reduction to preserve backwards
      # compatibility. Otherwise skip it to avoid unnecessary computation.
      if self._loss_reduction == losses.Reduction.SUM:
        # Broadcast weights to the loss shape so each loss element counts its
        # weight exactly once in the denominator.
        # NOTE(review): if every weight is zero this divides by zero and the
        # mean_loss summary becomes NaN/Inf.
        example_weight_sum = math_ops.reduce_sum(
            weights * array_ops.ones_like(unreduced_loss))
        mean_loss = training_loss / example_weight_sum
      else:
        mean_loss = None
    # An empty name_scope resets to the root scope, so summary tags are not
    # prefixed by the head's scope.
    with ops.name_scope(''):
      keys = metric_keys.MetricKeys
      summary.scalar(
          _summary_key(self._name, keys.LOSS),
          regularized_training_loss)
      if mean_loss is not None:
        summary.scalar(
            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
      if regularization_loss is not None:
        summary.scalar(
            _summary_key(self._name, keys.LOSS_REGULARIZATION),
            regularization_loss)
    # train_op_fn is invoked outside the head's name scope.
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op_fn(regularized_training_loss))
1369edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
1370edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
def _assert_range(labels, n_classes, message=None):
  """Returns `labels` gated on runtime checks that every ID is in range.

  The valid range is the half-open interval `[0, n_classes)`.

  Args:
    labels: Label-ID `Tensor` (presumably integer-typed; `n_classes` is
      converted to its dtype).
    n_classes: Exclusive upper bound for label IDs.
    message: Optional message used for both checks. When falsy, each check
      uses its own default message.

  Returns:
    An identity of `labels` with control dependencies on both range
    assertions, so consuming it forces the checks to run.
  """
  with ops.name_scope(None, 'assert_range', (labels,)):
    # Bring the bound into the labels' dtype so the comparison is well-typed.
    bound = ops.convert_to_tensor(n_classes, dtype=labels.dtype)
    upper_msg = message or 'Label IDs must < n_classes'
    below_bound = check_ops.assert_less(labels, bound, message=upper_msg)
    lower_msg = message or 'Label IDs must >= 0'
    non_negative = check_ops.assert_non_negative(labels, message=lower_msg)
    range_checks = (below_bound, non_negative)
    with ops.control_dependencies(range_checks):
      return array_ops.identity(labels)
1381d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir
1382d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir
1383d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower# TODO(b/69000400): Delete this method.
def _weights(features, weight_column):
  """Looks up per-example weights in `features`.

  Args:
    features: Mapping of input features. Its values seed the name scope, and
      it backs the `_LazyBuilder` used to materialize the weight column.
    weight_column: `None`, a feature key string, or a `_NumericColumn`.

  Returns:
    The Python float `1.` when `weight_column` is `None`. Otherwise, the
    weight tensor cast to float and named 'weights'.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the weight tensor's dtype is neither floating nor integer.
  """
  with ops.name_scope(None, 'weights', values=features.values()):
    if weight_column is None:
      return 1.
    column = weight_column
    if isinstance(column, six.string_types):
      # A bare feature key becomes a numeric column of shape (1,).
      column = feature_column_lib.numeric_column(key=column, shape=(1,))
    if not isinstance(column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
      raise TypeError('Weight column must be either a string or _NumericColumn.'
                      ' Given type: {}.'.format(type(column)))
    builder = feature_column_lib._LazyBuilder(features)  # pylint: disable=protected-access
    raw_weights = column._get_dense_tensor(builder)  # pylint: disable=protected-access
    weight_dtype = raw_weights.dtype
    if not weight_dtype.is_floating and not weight_dtype.is_integer:
      raise ValueError('Weight column should be castable to float. '
                       'Given dtype: {}'.format(weight_dtype))
    return math_ops.to_float(raw_weights, name='weights')
1401