head.py revision 2dab9fd3c89f47dbb0b5f4368084cebb56e03a09
134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# you may not use this file except in compliance with the License.
534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# You may obtain a copy of the License at
634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower#
934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
1034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
1134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# See the License for the specific language governing permissions and
1334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# limitations under the License.
1434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# ==============================================================================
1534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower"""Abstractions for the head(s) of a model."""
1634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
1734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import absolute_import
1834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import division
1934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import print_function
2034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport abc
22e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport collections
2334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport six
2534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
2634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator import model_fn
2734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import metric_keys
2834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import prediction_keys
2934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.export import export_output
30d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispirfrom tensorflow.python.feature_column import feature_column as feature_column_lib
3134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import dtypes
3234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import ops
3334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import sparse_tensor
3434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops
3534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import check_ops
36ce21f5e372360451528c9716a743b65308421423Mustafa Ispirfrom tensorflow.python.ops import lookup_ops
3734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops
3834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import metrics as metrics_lib
3934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import nn
4034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import string_ops
4134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import weights_broadcast_ops
4234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops.losses import losses
4334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging
4402ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispirfrom tensorflow.python.saved_model import signature_constants
45169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispirfrom tensorflow.python.summary import summary
4602ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir
# Default serving signature key, taken from TF Serving's signature constants.
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

# The above default is defined by TF Serving, but these next three are just
# a local convention without any special meaning.
_CLASSIFY_SERVING_KEY = 'classification'
_REGRESS_SERVING_KEY = 'regression'
_PREDICT_SERVING_KEY = 'predict'


# Result of `_Head.create_loss`: the unweighted loss `Tensor` plus the labels
# after any head-specific processing (vocabulary lookup, reshaping, ...).
LossAndLabels = collections.namedtuple(
    'LossAndLabels', ('unweighted_loss', 'processed_labels'))
588d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
598d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
600fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlowerdef _summary_key(head_name, val):
610fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower  return '%s/%s' % (val, head_name) if head_name else val
620fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower
630fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower
# A `__metaclass__` class attribute is only honored by Python 2; Python 3
# silently ignores it. `six.add_metaclass` applies ABCMeta on both, so the
# abstract members below are actually enforced.
@six.add_metaclass(abc.ABCMeta)
class _Head(object):
  """Interface for the head/top of a model.

  Given logits (or output of a hidden layer), a Head knows how to compute
  predictions, loss, train_op, metrics and export outputs. It is meant to:

  1. Simplify writing model_fn and to make model_fn more configurable
  2. Support wide range of machine learning models. Since most heads can work
     with logits, they can support DNN, RNN, Wide, Wide&Deep,
     Global objectives, Gradient boosted trees and many other types
     of machine learning models.

  Common usage:
  Here is a simplified model_fn to build a DNN regression model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally your callers can pass head to model_fn as a param.
      head = tf.contrib.learn.regression_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_estimator_spec(
          features=features,
          labels=labels,
          mode=mode,
          logits=logits,
          train_op_fn=_train_op_fn)
    ```

  There are cases where computing and applying gradients cannot be
  meaningfully captured with the train_op_fn we support (for example, with a
  sync optimizer). In such cases, you can take the responsibility on your own.
  Here is a common use case,
    ```python
    estimator_spec = head.create_estimator_spec(
        features=features,
        labels=labels,
        mode=mode,
        logits=logits,
        train_op_fn=tf.contrib.learn.no_op_train_fn)
    if mode == model_fn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=estimator_spec.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      ... update train_op and hooks in EstimatorSpec and return
    ```

  Subclasses must implement `name`, `logits_dimension`, `create_loss` and
  `create_estimator_spec`; a subclass missing any of them cannot be
  instantiated.
  """

  @abc.abstractproperty
  def name(self):
    """The name of this head.

    Returns:
      A string.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_loss(self, features, mode, logits, labels):
    """Returns a loss Tensor from provided logits.

    This function is designed to be used by framework developers.  Almost all
    users should use create_estimator_spec(), which calls this internally.
    `mode` and `features` are most likely not used, but some Head
    implementations may require them.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used for loss construction.
      labels: Labels `Tensor`, or `dict` of same.

    Returns:
      A LossAndLabels that contains the `Tensor` representing the loss and
      possibly processed labels (e.g. vocabulary lookup, shape manipulation,
      etc.), to be extendable in the future.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """Returns `EstimatorSpec` that a model_fn can return.

    Please note that,
    + All args must be passed via name.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used by the head.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
          to optimize the model with the loss. This is used in TRAIN mode and
          must not be None. None is allowed in other modes. If you want to
          optimize loss yourself you can pass `no_op_train_fn` and then use
          EstimatorSpec.loss to compute and apply gradients.

    Returns:
      `EstimatorSpec`.
    """
    raise NotImplementedError('Calling an abstract method.')
18734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _maybe_expand_dim(tensor):
  """Appends a trailing unit dimension when `tensor` has static rank 1.

  Args:
    tensor: A `Tensor` or value convertible to one. `SparseTensor` inputs are
      rejected.

  Returns:
    `array_ops.expand_dims(tensor, -1)` when the static rank is exactly 1;
    otherwise the converted `tensor` unchanged.

  Raises:
    ValueError: if `tensor` converts to a `SparseTensor`.
  """
  with ops.name_scope(None, 'maybe_expand_dim', (tensor,)):
    converted = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
    if isinstance(converted, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    shape = converted.shape
    is_rank_one = shape is not None and shape.ndims == 1
    if not is_rank_one:
      return converted
    return array_ops.expand_dims(converted, -1)
20062cf561f1d32abf4f5b7fbdee6d106389994ff05Jianwei Xie
20162cf561f1d32abf4f5b7fbdee6d106389994ff05Jianwei Xie
def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape.

  Validates that `labels` is a dense rank-2 tensor whose second dimension
  equals `expected_labels_dimension`. It checks statically where the shape is
  known, and adds graph assertions for the dynamic case.

  Args:
    labels: Labels value; converted with
      `convert_to_tensor_or_sparse_tensor`.
    expected_labels_dimension: Required size of `labels.shape[1]`.

  Returns:
    `array_ops.identity(labels)` named after this name scope, with control
    dependencies on the rank and dimension assertions.

  Raises:
    ValueError: if `labels` is a `SparseTensor`, or if the statically known
      second dimension differs from `expected_labels_dimension`.
  """
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    # Dynamic shape, used by the runtime dimension assertion below.
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      # NOTE(review): `labels.shape` is presumably always a TensorShape (never
      # None) for a dense Tensor, so this guard may always be True — confirm.
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        # Fail early at graph-construction time when the mismatch is static.
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          # NOTE(review): wording assumes a classifier (n_classes); confirm it
          # reads sensibly for any non-classifier caller of this helper.
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s.  Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      # Runtime check for when the second dimension is not statically known.
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
22734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
22834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _check_logits(logits, expected_logits_dimension):
  """Check logits type and shape.

  Casts `logits` with `math_ops.to_float`, then validates that it is rank 2 and
  that its second dimension equals `expected_logits_dimension`. It checks
  statically where the shape is known, and adds graph assertions otherwise.

  Args:
    logits: Logits value to validate.
    expected_logits_dimension: Required size of `logits.shape[1]`.

  Returns:
    The cast logits, passed through `array_ops.identity` and named after this
    name scope, with control dependencies on the rank and dimension
    assertions.

  Raises:
    ValueError: if the statically known second dimension differs from
      `expected_logits_dimension`.
  """
  with ops.name_scope(None, 'logits', (logits,)) as scope:
    logits = math_ops.to_float(logits)
    # Dynamic shape; also attached as `data` so failing asserts print it.
    logits_shape = array_ops.shape(logits)
    assert_rank = check_ops.assert_rank(
        logits, 2, data=[logits_shape],
        message='logits shape must be [batch_size, logits_dimension]')
    with ops.control_dependencies([assert_rank]):
      # NOTE(review): `logits.shape` is presumably always a TensorShape (never
      # None), so this guard may always be True — confirm.
      static_shape = logits.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        # Fail early at graph-construction time when the mismatch is static.
        if (dim1 is not None) and (dim1 != expected_logits_dimension):
          raise ValueError(
              'logits shape must be [batch_size, logits_dimension], got %s.' %
              (static_shape,))
      # Runtime check for when the second dimension is not statically known.
      assert_dimension = check_ops.assert_equal(
          expected_logits_dimension, logits_shape[1], data=[logits_shape],
          message='logits shape must be [batch_size, logits_dimension]')
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(logits, name=scope)
25034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
25134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _indicator_labels_mean(labels, weights=None, name=None):
  """Streaming (optionally weighted) mean of `labels` cast to float.

  Args:
    labels: Labels `Tensor`; cast with `math_ops.to_float`.
    weights: Optional weights, broadcast to the shape of the cast labels.
    name: Optional name scope; defaults to 'labels_mean'.

  Returns:
    The (value, update_op) pair produced by `metrics_lib.mean`.
  """
  with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
    float_labels = math_ops.to_float(labels, name='labels')
    broadcast_weights = (
        None if weights is None
        else weights_broadcast_ops.broadcast_weights(weights, float_labels))
    return metrics_lib.mean(float_labels, weights=broadcast_weights, name=scope)
25834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
25934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _accuracy_baseline(labels_mean):
  """Return accuracy baseline based on labels mean.

  The baseline is `max(mean, 1 - mean)`: the best accuracy achievable by
  always predicting a single class.

  Args:
    labels_mean: (value, update_op) pair of the labels-mean metric.

  Returns:
    (value, update_op) pair with the baseline applied to each element.
  """
  with ops.name_scope(None, 'accuracy_baseline', labels_mean):
    mean_value, mean_update = labels_mean
    baseline_value = math_ops.maximum(
        mean_value, 1. - mean_value, name='value')
    baseline_update = math_ops.maximum(
        mean_update, 1 - mean_update, name='update_op')
    return baseline_value, baseline_update
27634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
27734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _predictions_mean(predictions, weights=None, name=None):
  """Streaming (optionally weighted) mean of `predictions` cast to float.

  Args:
    predictions: Predictions `Tensor`; cast with `math_ops.to_float`.
    weights: Optional weights, broadcast to the shape of the cast predictions.
    name: Optional name scope; defaults to 'predictions_mean'.

  Returns:
    The (value, update_op) pair produced by `metrics_lib.mean`.
  """
  with ops.name_scope(
      name, 'predictions_mean', (predictions, weights)) as scope:
    float_predictions = math_ops.to_float(predictions, name='predictions')
    broadcast_weights = weights
    if broadcast_weights is not None:
      broadcast_weights = weights_broadcast_ops.broadcast_weights(
          broadcast_weights, float_predictions)
    return metrics_lib.mean(
        float_predictions, weights=broadcast_weights, name=scope)
28534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
28634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _auc(labels, predictions, weights=None, curve='ROC', name=None):
  """Streaming AUC of `predictions` against boolean `labels`.

  Non-bool labels are cast to bool, and a warning is logged when that happens.

  Args:
    labels: Labels `Tensor`.
    predictions: Predictions `Tensor`; cast with `math_ops.to_float`.
    weights: Optional weights, broadcast to the shape of the predictions.
    curve: Curve type forwarded to `metrics_lib.auc` (default 'ROC').
    name: Optional name scope; defaults to 'auc'.

  Returns:
    The (value, update_op) pair produced by `metrics_lib.auc`.
  """
  with ops.name_scope(name, 'auc', (predictions, labels, weights)) as scope:
    float_predictions = math_ops.to_float(predictions, name='predictions')
    bool_labels = labels
    if bool_labels.dtype.base_dtype != dtypes.bool:
      logging.warning('Casting %s labels to bool.', bool_labels.dtype)
      bool_labels = math_ops.cast(bool_labels, dtypes.bool)
    broadcast_weights = weights
    if broadcast_weights is not None:
      broadcast_weights = weights_broadcast_ops.broadcast_weights(
          broadcast_weights, float_predictions)
    return metrics_lib.auc(
        labels=bool_labels,
        predictions=float_predictions,
        weights=broadcast_weights,
        curve=curve,
        name=scope)
29834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
29934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _accuracy_at_threshold(labels, predictions, weights, threshold, name=None):
  """Streaming accuracy after binarizing `predictions` at `threshold`.

  A prediction counts as positive (1.0) when it is >= `threshold`, otherwise
  as 0.0.

  Args:
    labels: Labels `Tensor`.
    predictions: Predictions `Tensor` to compare against `threshold`.
    weights: Optional weights forwarded to `metrics_lib.accuracy`.
    threshold: Decision threshold.
    name: Optional name scope; defaults to 'accuracy_at_<threshold>'.

  Returns:
    The (value, update_op) pair produced by `metrics_lib.accuracy`.
  """
  default_scope_name = 'accuracy_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_scope_name, scope_values) as scope:
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.accuracy(
        labels=labels, predictions=binarized, weights=weights, name=scope)
30934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
31034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _precision_at_threshold(labels, predictions, weights, threshold, name=None):
  """Precision of `predictions` at a single decision `threshold`.

  Wraps `metrics_lib.precision_at_thresholds` for exactly one threshold and
  squeezes both returned tensors.

  Args:
    labels: Ground-truth labels, forwarded to the metric.
    predictions: Scores compared against `threshold`.
    weights: Example weights, forwarded to the metric.
    threshold: Decision threshold; also embedded in the default scope name.
    name: Optional name scope; defaults to `'precision_at_<threshold>'`.

  Returns:
    A `(precision, update_op)` tuple, each passed through
    `array_ops.squeeze`.
  """
  default_scope_name = 'precision_at_%s' % threshold
  scope_inputs = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_scope_name, scope_inputs) as scope:
    precision, update = metrics_lib.precision_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=(threshold,),
        weights=weights,
        name=scope)
    # Only one threshold was requested, so drop the length-1 threshold axis.
    return tuple(array_ops.squeeze(t) for t in (precision, update))
31934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
32034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
  """Recall of `predictions` at a single decision `threshold`.

  Wraps `metrics_lib.recall_at_thresholds` for exactly one threshold and
  squeezes both returned tensors.

  Args:
    labels: Ground-truth labels, forwarded to the metric.
    predictions: Scores compared against `threshold`.
    weights: Example weights, forwarded to the metric.
    threshold: Decision threshold; also embedded in the default scope name.
    name: Optional name scope; defaults to `'recall_at_<threshold>'`.

  Returns:
    A `(recall, update_op)` tuple, each passed through `array_ops.squeeze`.
  """
  with ops.name_scope(
      name, 'recall_at_%s' % threshold,
      (predictions, labels, weights, threshold)) as scope:
    recall_tensor, update_op = metrics_lib.recall_at_thresholds(
        labels=labels, predictions=predictions, thresholds=(threshold,),
        weights=weights, name=scope)
    # Only one threshold was requested, so drop the length-1 threshold axis.
    return array_ops.squeeze(recall_tensor), array_ops.squeeze(update_op)
32934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
33034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
331ce21f5e372360451528c9716a743b65308421423Mustafa Ispirdef _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
332d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir                                                      weight_column=None,
3330fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                                      label_vocabulary=None,
334abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower                                                      name=None):
33534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """Creates a '_Head' for multi class classification.
33634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
33734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  This head expects to be fed integer labels specifying the class index.
33834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
33934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
34034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
34134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
342d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    weight_column: A string or a `_NumericColumn` created by
343d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      `tf.feature_column.numeric_column` defining feature column representing
34434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights. It is used to down weight or boost examples during training. It
34534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      will be multiplied by the loss of the example.
346ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    label_vocabulary: A list of strings represents possible label values. If it
347ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      is not given, that means labels are already encoded as integer within
348ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      [0, n_classes). If given, labels must be string type and have any value in
349ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      `label_vocabulary`. Also there will be errors if vocabulary is not
350ce21f5e372360451528c9716a743b65308421423Mustafa Ispir      provided and labels are string.
351abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    name: name of the head. If provided, summary and metrics keys will be
352abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower      suffixed by `"/" + name`.
35334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
35434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
3550815de21239955e346b562e899640649c8d2b9cbBenoit Steiner    An instance of `_Head` for multi class classification.
35634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
35734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Raises:
35834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
35934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
360edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir  if label_vocabulary is not None and not isinstance(label_vocabulary,
361edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                                                     (list, tuple)):
362ce21f5e372360451528c9716a743b65308421423Mustafa Ispir    raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
363ce21f5e372360451528c9716a743b65308421423Mustafa Ispir        type(label_vocabulary)))
364ce21f5e372360451528c9716a743b65308421423Mustafa Ispir
365d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir  return _MultiClassHeadWithSoftmaxCrossEntropyLoss(n_classes, weight_column,
366abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower                                                    label_vocabulary, name)
36734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
36834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
  """See `_multi_class_head_with_softmax_cross_entropy_loss`."""

  def __init__(self,
               n_classes,
               weight_column=None,
               label_vocabulary=None,
               name=None):
    """Stores the head configuration.

    Args:
      n_classes: Number of classes; must be greater than 2.
      weight_column: Optional weight column (string or `_NumericColumn`),
        passed to `_weights` when computing losses and metrics.
      label_vocabulary: Optional list/tuple of string label values.
      name: Optional head name, passed to `_summary_key` for summary and
        metric keys.

    Raises:
      ValueError: If `n_classes` is `None` or not greater than 2.
    """
    if (n_classes is None) or (n_classes <= 2):
      raise ValueError('n_classes must be > 2: %s.' % n_classes)
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._label_vocabulary = label_vocabulary
    self._name = name

  @property
  def name(self):
    """The head's name as passed to the constructor (may be `None`)."""
    return self._name

  @property
  def logits_dimension(self):
    """Expected size of the last logits dimension: one logit per class."""
    return self._n_classes

  def _eval_metric_ops(self, labels, probabilities, logits,
                       class_ids, weights, unweighted_loss):
    """Returns the Eval metric ops.

    Args:
      labels: Integer label ids, as produced by `_label_ids`.
      probabilities: Softmax probabilities; only used as a name-scope input.
      logits: Logits; only used as a name-scope input.
      class_ids: Predicted class ids, compared against `labels` for accuracy.
      weights: Example weights applied to both metrics.
      unweighted_loss: Per-example loss, averaged into the mean-loss metric.

    Returns:
      A dict from (optionally name-suffixed) metric keys to the mean-loss and
      accuracy metric ops.
    """
    with ops.name_scope(
        None, 'metrics',
        (labels, probabilities, logits, class_ids, weights, unweighted_loss)):
      keys = metric_keys.MetricKeys
      metric_ops = {
          # Estimator already adds a metric for loss.
          # TODO(xiejw): Any other metrics?
          _summary_key(self._name, keys.LOSS_MEAN):
              metrics_lib.mean(
                  unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
          _summary_key(self._name, keys.ACCURACY):
              metrics_lib.accuracy(
                  labels=labels,
                  predictions=class_ids,
                  weights=weights,
                  name=keys.ACCURACY),
      }
    return metric_ops

  def _label_ids(self, labels):
    """Converts labels to integer id space.

    Args:
      labels: Label tensor. It must be integer-typed when no vocabulary was
        given, and string-typed otherwise.

    Returns:
      Integer label ids, passed through `_assert_range` with `n_classes`
      (presumably a runtime range check; see that helper).

    Raises:
      ValueError: If the dtype of `labels` does not match the vocabulary mode.
    """
    if self._label_vocabulary is None:
      if not labels.dtype.is_integer:
        raise ValueError('Labels dtype should be integer. '
                         'Instead got %s.' % labels.dtype)
      label_ids = labels
    else:
      if labels.dtype != dtypes.string:
        raise ValueError('Labels dtype should be string if there is a '
                         'vocabulary. Instead got {}'.format(labels.dtype))
      # Map each label string to its index in the vocabulary.
      label_ids = lookup_ops.index_table_from_tensor(
          vocabulary_list=tuple(self._label_vocabulary),
          name='class_id_lookup').lookup(labels)
    return _assert_range(label_ids, self._n_classes)

  def create_loss(self, features, mode, logits, labels):
    """See `Head`.

    Returns:
      A `LossAndLabels` holding the per-example softmax cross-entropy loss
      (with a trailing axis of size 1 re-added) and the integer label ids.
    """
    del mode, features  # Unused for this head.
    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))
    unweighted_loss = losses.sparse_softmax_cross_entropy(
        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
    return LossAndLabels(
        unweighted_loss=array_ops.expand_dims(unweighted_loss, axis=(1,)),
        processed_labels=label_ids)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Predictions are built in every mode. PREDICT returns early with export
    outputs. EVAL returns the loss plus metric ops. Any other mode returns a
    TRAIN spec whose `train_op` comes from `train_op_fn`.

    Raises:
      ValueError: If `mode` is neither PREDICT nor EVAL and `train_op_fn` is
        `None`.
    """
    with ops.name_scope('head'):
      logits = _check_logits(logits, self.logits_dimension)

      # Predict.
      pred_keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        # argmax over the class axis yields shape [batch_size]; expand it to
        # [batch_size, 1].
        class_ids = math_ops.argmax(logits, 1, name=pred_keys.CLASS_IDS)
        class_ids = array_ops.expand_dims(class_ids, axis=(1,))
        if self._label_vocabulary:
          # Map predicted ids back to the vocabulary strings.
          table = lookup_ops.index_to_string_table_from_tensor(
              vocabulary_list=self._label_vocabulary,
              name='class_string_lookup')
          classes = table.lookup(class_ids)
        else:
          classes = string_ops.as_string(class_ids, name='str_classes')

        probabilities = nn.softmax(logits, name=pred_keys.PROBABILITIES)
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.PROBABILITIES: probabilities,
            pred_keys.CLASS_IDS: class_ids,
            pred_keys.CLASSES: classes,
        }
      if mode == model_fn.ModeKeys.PREDICT:
        batch_size = array_ops.shape(probabilities)[0]
        export_class_list = self._label_vocabulary
        if not export_class_list:
          export_class_list = string_ops.as_string(
              math_ops.range(self._n_classes))
        # Repeat the full class list once per example in the batch.
        export_output_classes = array_ops.tile(
            input=array_ops.expand_dims(input=export_class_list, axis=0),
            multiples=[batch_size, 1])
        classifier_output = export_output.ClassificationOutput(
            scores=probabilities,
            # `ClassificationOutput` requires string classes.
            classes=export_output_classes)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classifier_output,
                _CLASSIFY_SERVING_KEY: classifier_output,
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      # Eval.
      unweighted_loss, label_ids = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      weights = _weights(features, self._weight_column)
      training_loss = losses.compute_weighted_loss(
          unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == model_fn.ModeKeys.EVAL:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=self._eval_metric_ops(
                labels=label_ids,
                probabilities=probabilities,
                logits=logits,
                class_ids=class_ids,
                unweighted_loss=unweighted_loss,
                weights=weights))

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # Summaries are written outside the 'head' scope.
    with ops.name_scope(''):
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS),
          training_loss)
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
          losses.compute_weighted_loss(
              unweighted_loss, weights=weights,
              reduction=losses.Reduction.MEAN))
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=training_loss,
        train_op=train_op_fn(training_loss))
52734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
52834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
52934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _binary_logistic_head_with_sigmoid_cross_entropy_loss(
530abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    weight_column=None, thresholds=None, label_vocabulary=None, name=None):
53134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """Creates a `Head` for single label binary classification.
53234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
53334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  This head uses `sigmoid_cross_entropy_with_logits` loss.
53434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
53534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  This head expects to be fed float labels of shape `(batch_size, 1)`.
53634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
53734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Args:
538d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    weight_column: A string or a `_NumericColumn` created by
539d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      `tf.feature_column.numeric_column` defining feature column representing
54034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      weights. It is used to down weight or boost examples during training. It
54134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      will be multiplied by the loss of the example.
54234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    thresholds: Iterable of floats in the range `(0, 1)`. For binary
54334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      classification metrics such as precision and recall, an eval metric is
54434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      generated for each threshold value. This threshold is applied to the
54534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      logistic values to determine the binary classification (i.e., above the
54634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      threshold is `true`, below is `false`.
547edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    label_vocabulary: A list of strings represents possible label values. If it
548edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      is not given, that means labels are already encoded within [0, 1]. If
549edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      given, labels must be string type and have any value in
550edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      `label_vocabulary`. Also there will be errors if vocabulary is not
551edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      provided and labels are string.
552abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    name: name of the head. If provided, summary and metrics keys will be
553abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower      suffixed by `"/" + name`.
55434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
55534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Returns:
55634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    An instance of `Head` for binary classification.
55734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
55834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  Raises:
55934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    ValueError: if `thresholds` contains a value outside of `(0, 1)`.
56034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """
56179099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower  thresholds = tuple(thresholds) if thresholds else tuple()
562edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir  if label_vocabulary is not None and not isinstance(label_vocabulary,
563edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                                                     (list, tuple)):
564edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
565edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir        type(label_vocabulary)))
566edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
56734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  for threshold in thresholds:
56834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    if (threshold <= 0.0) or (threshold >= 1.0):
56934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
57034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
571d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir      weight_column=weight_column,
572edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir      thresholds=thresholds,
5730fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower      label_vocabulary=label_vocabulary,
574abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower      name=name)
57534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
57634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
57734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
57834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
57934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
5800fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower  def __init__(self,
5810fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               weight_column=None,
5820fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               thresholds=None,
5830fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower               label_vocabulary=None,
584abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower               name=None):
585d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir    self._weight_column = weight_column
58679099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower    self._thresholds = thresholds
587edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    self._label_vocabulary = label_vocabulary
588abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    self._name = name
589abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower
590abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  @property
591abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower  def name(self):
592abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower    return self._name
59334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
59434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  @property
59534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower  def logits_dimension(self):
59634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower    return 1
59734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
598edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir  def _eval_metric_ops(self,
599edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       labels,
600edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       logits,
601edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       logistic,
602edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       scores,
603edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       class_ids,
604edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       unweighted_loss,
605edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                       weights=None):
606edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir    with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores,
607edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                                          class_ids, unweighted_loss, weights)):
60834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      keys = metric_keys.MetricKeys
60934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      labels_mean = _indicator_labels_mean(
61034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          labels=labels, weights=weights, name=keys.LABEL_MEAN)
61134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      metric_ops = {
61234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower          # Estimator already adds a metric for loss.
613abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.LOSS_MEAN):
614edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              metrics_lib.mean(
615edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
616abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.ACCURACY):
617edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              metrics_lib.accuracy(
618edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
619edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=class_ids,
620edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
621edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.ACCURACY),
622abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.PREDICTION_MEAN):
623edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _predictions_mean(
624edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
625edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
626edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.PREDICTION_MEAN),
627abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.LABEL_MEAN):
628edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              labels_mean,
629abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.ACCURACY_BASELINE):
630edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _accuracy_baseline(labels_mean),
631abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.AUC):
632edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _auc(
633edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
634edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
635edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
636edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.AUC),
637abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower          _summary_key(self._name, keys.AUC_PR):
638edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir              _auc(
639edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  labels=labels,
640edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  predictions=logistic,
641edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  weights=weights,
642edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  curve='PR',
643edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir                  name=keys.AUC_PR)
64434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      }
64534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      for threshold in self._thresholds:
64634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
647abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
6480fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                accuracy_key)] = _accuracy_at_threshold(
6490fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
6500fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
6510fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
6520fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
6530fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=accuracy_key)
65434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        # Precision for positive examples.
65534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        precision_key = keys.PRECISION_AT_THRESHOLD % threshold
656abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
6570fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                precision_key)] = _precision_at_threshold(
6580fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
6590fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
6600fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
6610fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
6620fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=precision_key)
66334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        # Recall for positive examples.
66434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower        recall_key = keys.RECALL_AT_THRESHOLD % threshold
665abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower        metric_ops[_summary_key(self._name,
6660fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                recall_key)] = _recall_at_threshold(
6670fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    labels=labels,
6680fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    predictions=logistic,
6690fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    weights=weights,
6700fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    threshold=threshold,
6710fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower                                    name=recall_key)
67234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower      return metric_ops
67334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
6748d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower  def create_loss(self, features, mode, logits, labels):
6758d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    """See `Head`."""
6768d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    del mode, features  # Unused for this head.
6778d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
6788d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    if self._label_vocabulary is not None:
6798d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower      labels = lookup_ops.index_table_from_tensor(
6808d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower          vocabulary_list=tuple(self._label_vocabulary),
6818d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower          name='class_id_lookup').lookup(labels)
6828d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    labels = math_ops.to_float(labels)
6838d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    labels = _assert_range(labels, 2)
6848d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower    return LossAndLabels(
6858d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower        unweighted_loss=nn.sigmoid_cross_entropy_with_logits(
6868d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower            labels=labels, logits=logits),
6878d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower        processed_labels=labels)
6888d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Builds an `EstimatorSpec` for a binary logistic classification head.

    Predictions are computed under the `head` name scope:
      * LOGISTIC: `sigmoid(logits)`.
      * PROBABILITIES: softmax over the two-column logits `[0, logits]`.
      * CLASS_IDS: argmax of the two-column logits, reshaped to `(-1, 1)`.
      * CLASSES: the class ids as strings, mapped through the label
        vocabulary when one is configured.

    Args:
      features: Feature mapping; used to fetch example weights via `_weights`.
      mode: One of `model_fn.ModeKeys`.
      logits: Logits tensor, validated by `_check_logits` against
        `logits_dimension`.
      labels: Labels; unused in PREDICT mode, required otherwise.
      train_op_fn: Callable taking the training loss and returning the train
        op; required in TRAIN mode.

    Returns:
      * PREDICT: a spec with classification, regression (logistic) and
        predict export outputs.
      * EVAL: a spec with `_eval_metric_ops`.
      * Otherwise (TRAIN): a spec whose train op is
        `train_op_fn(training_loss)`.
      In EVAL and TRAIN the loss is the weighted SUM of per-example losses.

    Raises:
      ValueError: If `mode` is neither PREDICT nor EVAL and `train_op_fn` is
        None.
    """
    # Predict.
    with ops.name_scope('head'):
      with ops.name_scope(None, 'predictions', (logits,)):
        pred_keys = prediction_keys.PredictionKeys
        logits = _check_logits(logits, self.logits_dimension)
        logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC)
        # Prepending a zero logit makes softmax column 1 equal to
        # sigmoid(logits), so the two-column scores agree with `logistic`.
        two_class_logits = array_ops.concat(
            (array_ops.zeros_like(logits), logits), 1, name='two_class_logits')
        scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES)
        class_ids = array_ops.reshape(
            math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes')
        if self._label_vocabulary:
          table = lookup_ops.index_to_string_table_from_tensor(
              vocabulary_list=self._label_vocabulary,
              name='class_string_lookup')
          classes = table.lookup(class_ids)
        else:
          classes = string_ops.as_string(class_ids, name='str_classes')
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.LOGISTIC: logistic,
            pred_keys.PROBABILITIES: scores,
            pred_keys.CLASS_IDS: class_ids,
            pred_keys.CLASSES: classes,
        }
      if mode == model_fn.ModeKeys.PREDICT:
        batch_size = array_ops.shape(logistic)[0]
        # Every example exports the same class list: the vocabulary if set,
        # otherwise the strings of ids 0 and 1, tiled once per batch row.
        export_class_list = self._label_vocabulary
        if not export_class_list:
          export_class_list = string_ops.as_string([0, 1])
        export_output_classes = array_ops.tile(
            input=array_ops.expand_dims(input=export_class_list, axis=0),
            multiples=[batch_size, 1])
        classifier_output = export_output.ClassificationOutput(
            scores=scores,
            # `ClassificationOutput` requires string classes.
            classes=export_output_classes)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classifier_output,
                _CLASSIFY_SERVING_KEY: classifier_output,
                _REGRESS_SERVING_KEY: export_output.RegressionOutput(
                    value=logistic),
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      # Eval.
      unweighted_loss, processed_labels = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      weights = _weights(features, self._weight_column)
      # Training loss is the weighted SUM (not mean) of per-example losses.
      training_loss = losses.compute_weighted_loss(
          unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == model_fn.ModeKeys.EVAL:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=self._eval_metric_ops(
                labels=processed_labels,
                logits=logits,
                logistic=logistic,
                scores=scores,
                class_ids=class_ids,
                unweighted_loss=unweighted_loss,
                weights=weights))

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # An empty name scope resets to the top level, so the summary names are
    # not nested under `head`.
    with ops.name_scope(''):
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS),
          training_loss)
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
          losses.compute_weighted_loss(
              unweighted_loss, weights=weights,
              reduction=losses.Reduction.MEAN))
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=training_loss,
        train_op=train_op_fn(training_loss))
77734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
77834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
def _regression_head_with_mean_squared_error_loss(weight_column=None,
                                                  label_dimension=1,
                                                  name=None):
  """Creates a `_Head` for regression using the mean squared loss.

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining the feature column that holds
      per-example weights. The weight multiplies that example's loss, so it
      can down-weight or boost examples during training.
    label_dimension: Number of regression labels per example, i.e. the size of
      the last dimension of the labels `Tensor` (typically shaped
      `[batch_size, label_dimension]`).
    name: Name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    A `_RegressionHeadWithMeanSquaredErrorLoss` instance, whose predictions
    are the logits themselves.

  Raises:
    ValueError: If `label_dimension` is less than 1.
  """
  return _RegressionHeadWithMeanSquaredErrorLoss(
      label_dimension=label_dimension,
      weight_column=weight_column,
      name=name)
80234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
80334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower
class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
  """`Head` for regression using the mean squared loss.

  Logits are emitted unchanged as predictions. The per-example loss is the
  squared error between labels and logits. In EVAL and TRAIN the reported loss
  is the weighted sum of those per-example losses.
  """

  def __init__(self, label_dimension, weight_column=None, name=None):
    """`Head` for regression.

    Args:
      label_dimension: Number of regression targets per example. Used as the
        expected last dimension of labels and logits.
      weight_column: `None`, a feature key string, or a `_NumericColumn`
        selecting per-example weights (resolved by `_weights`).
      name: Optional head name. If provided, summary and metric keys are
        suffixed by `"/" + name` (via `_summary_key`).

    Raises:
      ValueError: If `label_dimension` is less than 1.
    """
    if label_dimension < 1:
      raise ValueError('Invalid label_dimension %s.' % label_dimension)
    self._logits_dimension = label_dimension
    self._weight_column = weight_column
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def logits_dimension(self):
    return self._logits_dimension

  def create_loss(self, features, mode, logits, labels):
    """See `Head`.

    Casts `labels` to float and validates them against `logits_dimension`.
    Returns `LossAndLabels` holding the unreduced squared error and the
    processed labels.
    """
    del mode, features  # Unused for this head.
    labels = _check_labels(
        _maybe_expand_dim(math_ops.to_float(labels)), self._logits_dimension)
    return LossAndLabels(
        unweighted_loss=losses.mean_squared_error(
            labels=labels, predictions=logits, reduction=losses.Reduction.NONE),
        processed_labels=labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Returns:
      * PREDICT: a spec with regression and predict export outputs.
      * EVAL: a spec with the weighted mean-loss metric.
      * Otherwise (TRAIN): a spec whose train op is
        `train_op_fn(training_loss)`.

    Raises:
      ValueError: If `mode` is neither PREDICT nor EVAL and `train_op_fn` is
        None.
    """
    # Predict.
    with ops.name_scope('head'):
      logits = _check_logits(logits, self._logits_dimension)
      predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits}
      if mode == model_fn.ModeKeys.PREDICT:
        regression_output = export_output.RegressionOutput(value=logits)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: regression_output,
                _REGRESS_SERVING_KEY: regression_output,
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      # Eval.
      unweighted_loss, _ = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      weights = _weights(features, self._weight_column)
      # Training loss is the weighted SUM (not mean) of per-example losses.
      training_loss = losses.compute_weighted_loss(
          unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == model_fn.ModeKeys.EVAL:
        keys = metric_keys.MetricKeys
        # Estimator already adds a metric for loss.
        # The mean-loss key goes through `_summary_key` so a named head's
        # metric is suffixed with its name. This matches the training
        # summaries below and the other heads' metrics; previously two named
        # heads would collide on the same unsuffixed key.
        eval_metric_ops = {
            _summary_key(self._name, keys.LOSS_MEAN):
                metrics_lib.mean(
                    unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
        }
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=eval_metric_ops)

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # An empty name scope resets to the top level, so the summary names are
    # not nested under `head`.
    with ops.name_scope(''):
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS),
          training_loss)
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
          losses.compute_weighted_loss(
              unweighted_loss, weights=weights,
              reduction=losses.Reduction.MEAN))
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=training_loss,
        train_op=train_op_fn(training_loss))
886edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
887edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir
def _assert_range(labels, n_classes):
  """Returns `labels` gated by runtime checks that `0 <= labels < n_classes`.

  Args:
    labels: Label tensor; `n_classes` is converted to its dtype for comparison.
    n_classes: Exclusive upper bound for the label values.

  Returns:
    An identity of `labels` that depends on both range assertions, so reading
    it triggers the checks.
  """
  with ops.name_scope(None, 'assert_range', (labels,)):
    upper_bound = ops.convert_to_tensor(n_classes, dtype=labels.dtype)
    below_upper = check_ops.assert_less(
        labels, upper_bound, message='Label IDs must < n_classes')
    non_negative = check_ops.assert_non_negative(
        labels, message='Label IDs must >= 0')
    range_checks = (below_upper, non_negative)
    with ops.control_dependencies(range_checks):
      return array_ops.identity(labels)
898d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir
899d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir
def _weights(features, weight_column):
  """Fetches weights from features.

  Args:
    features: Mapping of feature tensors. Its values seed the name scope, and
      it is wrapped in a `_LazyBuilder` to materialize the weight column.
    weight_column: `None`, a feature key string, or a `_NumericColumn`.

  Returns:
    The Python float `1.` when `weight_column` is `None`. Otherwise, the
    weight tensor cast with `math_ops.to_float` and passed through
    `_maybe_expand_dim`.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the weight tensor's dtype is neither floating nor integer.
  """
  # The `None` shortcut stays inside the scope so graph naming is unchanged.
  with ops.name_scope(None, 'weights', values=features.values()):
    if weight_column is None:
      return 1.
    column = (feature_column_lib.numeric_column(key=weight_column)
              if isinstance(weight_column, six.string_types)
              else weight_column)
    if not isinstance(column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
      raise TypeError('Weight column must be either a string or _NumericColumn.'
                      ' Given type: {}.'.format(type(column)))
    builder = feature_column_lib._LazyBuilder(features)  # pylint: disable=protected-access
    raw_weights = column._get_dense_tensor(builder)  # pylint: disable=protected-access
    weights_dtype = raw_weights.dtype
    if not (weights_dtype.is_floating or weights_dtype.is_integer):
      raise ValueError('Weight column should be castable to float. '
                       'Given dtype: {}'.format(weights_dtype))
    return _maybe_expand_dim(math_ops.to_float(raw_weights, name='weights'))
917