head.py revision 2dab9fd3c89f47dbb0b5f4368084cebb56e03a09
134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# you may not use this file except in compliance with the License. 534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# You may obtain a copy of the License at 634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 1034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 1134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# See the License for the specific language governing permissions and 1334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# limitations under the License. 1434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# ============================================================================== 1534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower"""Abstractions for the head(s) of a model.""" 1634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 1734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import absolute_import 1834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import division 1934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlowerfrom __future__ import print_function 2034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport abc 22e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport collections 2334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport six 2534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator import model_fn 2734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import metric_keys 2834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import prediction_keys 2934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.export import export_output 30d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispirfrom tensorflow.python.feature_column import feature_column as feature_column_lib 3134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import dtypes 3234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import ops 3334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import sparse_tensor 3434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops 3534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import check_ops 36ce21f5e372360451528c9716a743b65308421423Mustafa Ispirfrom tensorflow.python.ops import lookup_ops 3734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops 3834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary

_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

# The above default is defined by TF Serving, but these next three are just
# a local convention without any special meaning.
_CLASSIFY_SERVING_KEY = 'classification'
_REGRESS_SERVING_KEY = 'regression'
_PREDICT_SERVING_KEY = 'predict'


# Pair returned by `_Head.create_loss`: the loss tensor together with the
# labels after whatever processing the head applied to them.
LossAndLabels = collections.namedtuple('LossAndLabels',
                                       ['unweighted_loss', 'processed_labels'])


def _summary_key(head_name, val):
  """Returns `val` suffixed with '/' + `head_name`.

  Args:
    head_name: Name of the head; any falsy value (e.g. `None` or '') means
      "no suffix".
    val: Base key.

  Returns:
    '<val>/<head_name>' when `head_name` is truthy, otherwise `val` unchanged.
  """
  return '%s/%s' % (val, head_name) if head_name else val


class _Head(object):
  """Interface for the head/top of a model.

  Given logits (or output of a hidden layer), a Head knows how to compute
  predictions, loss, train_op, metrics and export outputs. It is meant to:

  1. Simplify writing model_fn and to make model_fn more configurable
  2. Support wide range of machine learning models. Since most heads can work
     with logits, they can support DNN, RNN, Wide, Wide&Deep,
     Global objectives, Gradient boosted trees and many other types
     of machine learning models.

  Common usage:
  Here is simplified model_fn to build a DNN regression model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally your callers can pass head to model_fn as a param.
      head = tf.contrib.learn.regression_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_estimator_spec(
          features=features,
          labels=labels,
          mode=mode,
          logits=logits,
          train_op_fn=_train_op_fn)
    ```

  There are cases where computing and applying gradients can not be meaningfully
  captured with train_op_fn we support (for example, with sync optimizer). In
  such case, you can take the responsibility on your own. Here is a common
  use case,
    ```python
    estimator_spec = head.create_estimator_spec(
        features=features,
        labels=labels,
        mode=mode,
        logits=logits,
        train_op_fn=tf.contrib.learn.no_op_train_fn)
    if mode == model_fn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=estimator_spec.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      ... update train_op and hooks in EstimatorSpec and return
    ```
  """
  # NOTE(review): a class-body `__metaclass__` attribute only selects the
  # metaclass on Python 2. On Python 3 it is an ordinary attribute, so the
  # abstract members below would not be enforced at instantiation time.
  # `six.add_metaclass(abc.ABCMeta)` would cover both -- confirm that every
  # subclass implements all abstract members before switching.
  __metaclass__ = abc.ABCMeta

  @abc.abstractproperty
  def name(self):
    """The name of this head.

    Returns:
      A string.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_loss(self, features, mode, logits, labels):
    """Returns a loss Tensor from provided logits.

    This function is designed to be used by framework developers.  Almost all
    users should use create_estimator_spec(), which calls this internally.
    `mode` and `features` are most likely not used, but some Head
    implementations may require them.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used for loss construction.
      labels: Labels `Tensor`, or `dict` of same.

    Returns:
      A LossAndLabels that contains the `Tensor` representing the loss and
      possibly processed labels (e.g. vocabulary lookup, shape manipulation,
      etc.), to be extendable in the future.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """Returns `EstimatorSpec` that a model_fn can return.

    Please note that,
    + All args must be passed via name.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used by the head.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
          to optimize the model with the loss. This is used in TRAIN mode and
          must not be None. None is allowed in other modes. If you want to
          optimize loss yourself you can pass `no_op_train_fn` and then use
          EstimatorSpec.loss to compute and apply gradients.

    Returns:
      `EstimatorSpec`.
    """
    raise NotImplementedError('Calling an abstract method.')


def _maybe_expand_dim(tensor):
  """Expand the dim of `tensor` with static rank 1.

  Args:
    tensor: Value convertible to a `Tensor`.

  Returns:
    `tensor` (after conversion) with a trailing dimension appended when its
    static rank is exactly 1; otherwise the converted tensor unchanged.

  Raises:
    ValueError: if the converted value is a `SparseTensor`.
  """
  with ops.name_scope(None, 'maybe_expand_dim', (tensor,)):
    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
    if isinstance(tensor, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    static_shape = tensor.shape
    if static_shape is None:
      return tensor

    # Only a statically known rank of 1 is expanded; any other rank
    # (including an unknown one) is passed through untouched.
    return (array_ops.expand_dims(tensor, -1) if static_shape.ndims == 1
            else tensor)


def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape.

  Args:
    labels: Value convertible to a `Tensor`; must not be a `SparseTensor`.
    expected_labels_dimension: Required size of dimension 1 of `labels`.

  Returns:
    An identity of `labels` that carries the rank and dimension assertions
    created here as control dependencies.

  Raises:
    ValueError: if `labels` is a `SparseTensor`, or if the static size of
      dimension 1 is known and differs from `expected_labels_dimension`.
  """
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      # Fail early, at graph-construction time, when the static shape already
      # shows a mismatch; the assert_equal op below is the run-time check.
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          # NOTE(review): the wording refers to a classifier and `n_classes`
          # although this helper only receives a generic expected dimension --
          # confirm whether non-classification callers reach this message.
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s. Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)


def _check_logits(logits, expected_logits_dimension):
  """Check logits type and shape.

  Args:
    logits: Value to validate; it is cast with `math_ops.to_float` first.
    expected_logits_dimension: Required size of dimension 1 of `logits`.

  Returns:
    An identity of the float `logits` that carries the rank and dimension
    assertions created here as control dependencies.

  Raises:
    ValueError: if the static size of dimension 1 is known and differs from
      `expected_logits_dimension`.
  """
  with ops.name_scope(None, 'logits', (logits,)) as scope:
    logits = math_ops.to_float(logits)
    logits_shape = array_ops.shape(logits)
    assert_rank = check_ops.assert_rank(
        logits, 2, data=[logits_shape],
        message='logits shape must be [batch_size, logits_dimension]')
    with ops.control_dependencies([assert_rank]):
      # Static check first; the assert_equal op below is the run-time check.
      static_shape = logits.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_logits_dimension):
          raise ValueError(
              'logits shape must be [batch_size, logits_dimension], got %s.' %
              (static_shape,))
      assert_dimension = check_ops.assert_equal(
          expected_logits_dimension, logits_shape[1], data=[logits_shape],
          message='logits shape must be [batch_size, logits_dimension]')
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(logits, name=scope)


def _indicator_labels_mean(labels, weights=None, name=None):
  """Builds a mean metric over `labels` cast to float.

  Args:
    labels: Labels to average; cast with `math_ops.to_float`.
    weights: Optional weights; when given they are broadcast against the
      float labels before being passed to the metric.
    name: Optional name scope; the default scope name is 'labels_mean'.

  Returns:
    The result of `metrics_lib.mean` (`_accuracy_baseline` unpacks such a
    result as a value / update-op pair).
  """
  with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
    labels = math_ops.to_float(labels, name='labels')
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
    return metrics_lib.mean(labels, weights=weights, name=scope)


def _accuracy_baseline(labels_mean):
  """Return accuracy baseline based on labels mean.

  This is the best the model could do by always predicting one class.

  Args:
    labels_mean: Tuple of value and update op.

  Returns:
    Tuple of value and update op.
  """
  with ops.name_scope(None, 'accuracy_baseline', labels_mean):
    value, update_op = labels_mean
    # max(m, 1 - m) is applied to both members of the pair so the returned
    # value and update op describe the same quantity.
    return (
        math_ops.maximum(value, 1. - value, name='value'),
        math_ops.maximum(update_op, 1 - update_op, name='update_op'))



def _predictions_mean(predictions, weights=None, name=None):
  """Builds a mean metric over `predictions` cast to float.

  Args:
    predictions: Values to average; cast with `math_ops.to_float`.
    weights: Optional weights; when given they are broadcast against the
      float predictions before being passed to the metric.
    name: Optional name scope; the default scope name is 'predictions_mean'.

  Returns:
    The result of `metrics_lib.mean`.
  """
  scope_values = (predictions, weights)
  with ops.name_scope(name, 'predictions_mean', scope_values) as scope:
    float_predictions = math_ops.to_float(predictions, name='predictions')
    if weights is None:
      metric_weights = None
    else:
      metric_weights = weights_broadcast_ops.broadcast_weights(
          weights, float_predictions)
    return metrics_lib.mean(
        float_predictions, weights=metric_weights, name=scope)


def _auc(labels, predictions, weights=None, curve='ROC', name=None):
  """Builds an AUC metric, casting labels to bool when necessary.

  Args:
    labels: Label tensor. If its base dtype is not bool, a warning is logged
      and the labels are cast to bool.
    predictions: Prediction values; cast with `math_ops.to_float`.
    weights: Optional weights; when given they are broadcast against the
      float predictions.
    curve: Forwarded unchanged to `metrics_lib.auc`.
    name: Optional name scope; the default scope name is 'auc'.

  Returns:
    The result of `metrics_lib.auc`.
  """
  with ops.name_scope(name, 'auc', (predictions, labels, weights)) as scope:
    float_predictions = math_ops.to_float(predictions, name='predictions')

    needs_bool_cast = labels.dtype.base_dtype != dtypes.bool
    if needs_bool_cast:
      logging.warning('Casting %s labels to bool.', labels.dtype)
      labels = math_ops.cast(labels, dtypes.bool)

    metric_weights = weights
    if metric_weights is not None:
      metric_weights = weights_broadcast_ops.broadcast_weights(
          metric_weights, float_predictions)

    return metrics_lib.auc(
        labels=labels,
        predictions=float_predictions,
        weights=metric_weights,
        curve=curve,
        name=scope)


def _accuracy_at_threshold(labels, predictions, weights, threshold, name=None):
  """Builds an accuracy metric for predictions binarized at `threshold`.

  Args:
    labels: Labels forwarded to `metrics_lib.accuracy`.
    predictions: Values compared against `threshold` with `>=`; the result is
      cast to float and used as the predicted value.
    weights: Weights forwarded to `metrics_lib.accuracy`.
    threshold: Decision threshold; also embedded in the default scope name.
    name: Optional name scope; defaults to 'accuracy_at_<threshold>'.

  Returns:
    The result of `metrics_lib.accuracy`.
  """
  default_scope = 'accuracy_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_scope, scope_values) as scope:
    at_or_above = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(at_or_above)
    return metrics_lib.accuracy(
        labels=labels, predictions=binarized, weights=weights, name=scope)


def _precision_at_threshold(labels, predictions, weights, threshold, name=None):
  """Builds a precision metric for a single `threshold`.

  Args:
    labels: Labels forwarded to `metrics_lib.precision_at_thresholds`.
    predictions: Predictions forwarded to the same call.
    weights: Weights forwarded to the same call.
    threshold: The one threshold to evaluate; passed as a one-element tuple
      and also embedded in the default scope name.
    name: Optional name scope; defaults to 'precision_at_<threshold>'.

  Returns:
    A pair: the metric tensor and its update op, each passed through
    `array_ops.squeeze`.
  """
  default_scope = 'precision_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_scope, scope_values) as scope:
    metric_pair = metrics_lib.precision_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=(threshold,),
        weights=weights,
        name=scope)
    precision_values, precision_update = metric_pair
    squeezed_value = array_ops.squeeze(precision_values)
    squeezed_update = array_ops.squeeze(precision_update)
    return squeezed_value, squeezed_update


def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
  """Builds a recall metric for a single `threshold`.

  Args:
    labels: Labels forwarded to `metrics_lib.recall_at_thresholds`.
    predictions: Predictions forwarded to the same call.
    weights: Weights forwarded to the same call.
    threshold: The one threshold to evaluate; passed as a one-element tuple
      and also embedded in the default scope name.
    name: Optional name scope; defaults to 'recall_at_<threshold>'.

  Returns:
    A pair: the metric tensor and its update op, each passed through
    `array_ops.squeeze`.
  """
  default_scope = 'recall_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_scope, scope_values) as scope:
    metric_pair = metrics_lib.recall_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=(threshold,),
        weights=weights,
        name=scope)
    recall_values, recall_update = metric_pair
    squeezed_value = array_ops.squeeze(recall_values)
    squeezed_update = array_ops.squeeze(recall_update)
    return squeezed_value, squeezed_update

Unique TensorFlower 331ce21f5e372360451528c9716a743b65308421423Mustafa Ispirdef _multi_class_head_with_softmax_cross_entropy_loss(n_classes, 332d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column=None, 3330fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower label_vocabulary=None, 334abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name=None): 33534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """Creates a '_Head' for multi class classification. 33634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 33734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower This head expects to be fed integer labels specifying the class index. 33834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 33934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Args: 34034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower n_classes: Number of classes, must be greater than 2 (for 2 classes, use 34134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`). 342d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column: A string or a `_NumericColumn` created by 343d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir `tf.feature_column.numeric_column` defining feature column representing 34434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights. It is used to down weight or boost examples during training. It 34534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower will be multiplied by the loss of the example. 346ce21f5e372360451528c9716a743b65308421423Mustafa Ispir label_vocabulary: A list of strings represents possible label values. If it 347ce21f5e372360451528c9716a743b65308421423Mustafa Ispir is not given, that means labels are already encoded as integer within 348ce21f5e372360451528c9716a743b65308421423Mustafa Ispir [0, n_classes). 
If given, labels must be string type and have any value in 349ce21f5e372360451528c9716a743b65308421423Mustafa Ispir `label_vocabulary`. Also there will be errors if vocabulary is not 350ce21f5e372360451528c9716a743b65308421423Mustafa Ispir provided and labels are string. 351abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name: name of the head. If provided, summary and metrics keys will be 352abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower suffixed by `"/" + name`. 35334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 35434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Returns: 3550815de21239955e346b562e899640649c8d2b9cbBenoit Steiner An instance of `_Head` for multi class classification. 35634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 35734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Raises: 35834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. 35934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """ 360edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir if label_vocabulary is not None and not isinstance(label_vocabulary, 361edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir (list, tuple)): 362ce21f5e372360451528c9716a743b65308421423Mustafa Ispir raise ValueError('label_vocabulary should be a list. Given type: {}'.format( 363ce21f5e372360451528c9716a743b65308421423Mustafa Ispir type(label_vocabulary))) 364ce21f5e372360451528c9716a743b65308421423Mustafa Ispir 365d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir return _MultiClassHeadWithSoftmaxCrossEntropyLoss(n_classes, weight_column, 366abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower label_vocabulary, name) 36734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 36834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 36934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
  """Multi class head; see `_multi_class_head_with_softmax_cross_entropy_loss`."""

  def __init__(self,
               n_classes,
               weight_column=None,
               label_vocabulary=None,
               name=None):
    """Stores the head configuration after checking `n_classes` > 2."""
    if n_classes is None or n_classes <= 2:
      raise ValueError('n_classes must be > 2: %s.' % n_classes)
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._label_vocabulary = label_vocabulary
    self._name = name

  @property
  def name(self):
    """Name given at construction time (may be `None`)."""
    return self._name

  @property
  def logits_dimension(self):
    """One logit per class."""
    return self._n_classes

  def _eval_metric_ops(self, labels, probabilities, logits,
                       class_ids, weights, unweighted_loss):
    """Returns the eval metrics dict: weighted mean loss and accuracy."""
    scope_values = (
        labels, probabilities, logits, class_ids, weights, unweighted_loss)
    with ops.name_scope(None, 'metrics', scope_values):
      keys = metric_keys.MetricKeys
      metric_ops = {
          # Estimator already adds a metric for loss.
          # TODO(xiejw): Any other metrics?
          _summary_key(self._name, keys.LOSS_MEAN):
              metrics_lib.mean(
                  unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
          _summary_key(self._name, keys.ACCURACY):
              metrics_lib.accuracy(
                  labels=labels,
                  predictions=class_ids,
                  weights=weights,
                  name=keys.ACCURACY),
      }
    return metric_ops

  def _label_ids(self, labels):
    """Converts labels to integer ids and applies `_assert_range` to them.

    Args:
      labels: labels `Tensor`; must have an integer dtype when the head has no
        vocabulary, and string dtype when it has one.

    Returns:
      The result of `_assert_range(label_ids, n_classes)`.

    Raises:
      ValueError: if the dtype of `labels` does not match the vocabulary mode.
    """
    vocabulary = self._label_vocabulary
    if vocabulary is not None:
      # String labels are mapped to ids through a vocabulary lookup table.
      if labels.dtype != dtypes.string:
        raise ValueError('Labels dtype should be string if there is a '
                         'vocabulary. Instead got {}'.format(labels.dtype))
      id_table = lookup_ops.index_table_from_tensor(
          vocabulary_list=tuple(vocabulary), name='class_id_lookup')
      label_ids = id_table.lookup(labels)
    else:
      # Without a vocabulary the labels are used as ids directly.
      if not labels.dtype.is_integer:
        raise ValueError('Labels dtype should be integer '
                         'Instead got %s.' % labels.dtype)
      label_ids = labels
    return _assert_range(label_ids, self._n_classes)

  def create_loss(self, features, mode, logits, labels):
    """See `Head`."""
    del mode, features  # Unused for this head.
    checked_labels = _check_labels(_maybe_expand_dim(labels), 1)
    label_ids = self._label_ids(checked_labels)
    per_example_loss = losses.sparse_softmax_cross_entropy(
        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
    # Restore the squeezed dim, so the unweighted loss matches the weights
    # shape.
    return LossAndLabels(
        unweighted_loss=array_ops.expand_dims(per_example_loss, axis=(1,)),
        processed_labels=label_ids)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Builds the predictions for every mode, then returns early for PREDICT
    (with export outputs) and EVAL (with loss and metrics); otherwise emits the
    loss summaries and returns the TRAIN spec.

    Raises:
      ValueError: if TRAIN mode is reached and `train_op_fn` is `None`.
    """
    mode_keys = model_fn.ModeKeys
    with ops.name_scope('head'):
      logits = _check_logits(logits, self.logits_dimension)

      # Predict.
      pred_keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        # argmax over axis 1, followed by expand_dims on axis 1.
        class_ids = math_ops.argmax(logits, 1, name=pred_keys.CLASS_IDS)
        class_ids = array_ops.expand_dims(class_ids, axis=(1,))
        if self._label_vocabulary:
          id_to_string = lookup_ops.index_to_string_table_from_tensor(
              vocabulary_list=self._label_vocabulary,
              name='class_string_lookup')
          classes = id_to_string.lookup(class_ids)
        else:
          classes = string_ops.as_string(class_ids, name='str_classes')

        probabilities = nn.softmax(logits, name=pred_keys.PROBABILITIES)
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.PROBABILITIES: probabilities,
            pred_keys.CLASS_IDS: class_ids,
            pred_keys.CLASSES: classes,
        }
      if mode == mode_keys.PREDICT:
        batch_size = array_ops.shape(probabilities)[0]
        # Exported class names: the vocabulary when there is one, otherwise
        # the stringified class indices.
        class_names = self._label_vocabulary
        if not class_names:
          class_names = string_ops.as_string(math_ops.range(self._n_classes))
        tiled_class_names = array_ops.tile(
            input=array_ops.expand_dims(input=class_names, axis=0),
            multiples=[batch_size, 1])
        # `ClassificationOutput` requires string classes.
        classification = export_output.ClassificationOutput(
            scores=probabilities, classes=tiled_class_names)
        return model_fn.EstimatorSpec(
            mode=mode_keys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classification,
                _CLASSIFY_SERVING_KEY: classification,
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      # Eval.
      unweighted_loss, label_ids = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      weights = _weights(features, self._weight_column)
      training_loss = losses.compute_weighted_loss(
          unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == mode_keys.EVAL:
        eval_metrics = self._eval_metric_ops(
            labels=label_ids,
            probabilities=probabilities,
            logits=logits,
            class_ids=class_ids,
            unweighted_loss=unweighted_loss,
            weights=weights)
        return model_fn.EstimatorSpec(
            mode=mode_keys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=eval_metrics)

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # NOTE(review): the summaries are created under name scope '' and the TRAIN
    # spec is built after the 'head' scope has been left -- confirm this
    # nesting against the upstream file.
    with ops.name_scope(''):
      loss_keys = metric_keys.MetricKeys
      summary.scalar(
          _summary_key(self._name, loss_keys.LOSS), training_loss)
      summary.scalar(
          _summary_key(self._name, loss_keys.LOSS_MEAN),
          losses.compute_weighted_loss(
              unweighted_loss, weights=weights,
              reduction=losses.Reduction.MEAN))
    return model_fn.EstimatorSpec(
        mode=mode_keys.TRAIN,
        predictions=predictions,
        loss=training_loss,
        train_op=train_op_fn(training_loss))


def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
    weight_column=None, thresholds=None, label_vocabulary=None, name=None):
  """Creates a `Head` for single label binary classification.

  This head uses `sigmoid_cross_entropy_with_logits` loss.

  This head expects to be fed float labels of shape `(batch_size, 1)`.

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    thresholds: Iterable of floats in the range `(0, 1)`. For binary
      classification metrics such as precision and recall, an eval metric is
      generated for each threshold value. This threshold is applied to the
      logistic values to determine the binary classification (i.e., above the
      threshold is `true`, below is `false`).
    label_vocabulary: A list of strings represents possible label values. If it
      is not given, that means labels are already encoded within [0, 1]. If
      given, labels must be string type and have any value in
      `label_vocabulary`. Also there will be errors if vocabulary is not
      provided and labels are string.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `Head` for binary classification.

  Raises:
    ValueError: if `thresholds` contains a value outside of `(0, 1)`, or if
      `label_vocabulary` is given but is neither a list nor a tuple.
  """
  # Freeze the thresholds so the head iterates over a stable tuple.
  thresholds = tuple(thresholds) if thresholds else ()
  vocabulary_is_sequence = isinstance(label_vocabulary, (list, tuple))
  if label_vocabulary is not None and not vocabulary_is_sequence:
    raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
        type(label_vocabulary)))

  # Both interval ends are excluded.
  if any(value <= 0.0 or value >= 1.0 for value in thresholds):
    raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
  return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      name=name)


class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
  """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""

  def __init__(self,
               weight_column=None,
               thresholds=None,
               label_vocabulary=None,
               name=None):
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._label_vocabulary = label_vocabulary
    self._name = name

Unique TensorFlower 590abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower @property 591abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower def name(self): 592abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower return self._name 59334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 59434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower @property 59534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def logits_dimension(self): 59634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return 1 59734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 598edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir def _eval_metric_ops(self, 599edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels, 600edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir logits, 601edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir logistic, 602edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir scores, 603edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir class_ids, 604edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir unweighted_loss, 605edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=None): 606edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores, 607edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir class_ids, unweighted_loss, weights)): 60834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower keys = metric_keys.MetricKeys 60934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels_mean = _indicator_labels_mean( 61034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, weights=weights, name=keys.LABEL_MEAN) 61134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower metric_ops = { 61234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Estimator already adds a metric for loss. 613abf4aa037f21231594c44fc08fb50607de0288b7A. 
Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN): 614edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir metrics_lib.mean( 615edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir unweighted_loss, weights=weights, name=keys.LOSS_MEAN), 616abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.ACCURACY): 617edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir metrics_lib.accuracy( 618edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 619edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=class_ids, 620edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 621edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.ACCURACY), 622abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.PREDICTION_MEAN): 623edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _predictions_mean( 624edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 625edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 626edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.PREDICTION_MEAN), 627abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.LABEL_MEAN): 628edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels_mean, 629abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.ACCURACY_BASELINE): 630edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _accuracy_baseline(labels_mean), 631abf4aa037f21231594c44fc08fb50607de0288b7A. 
Unique TensorFlower _summary_key(self._name, keys.AUC): 632edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _auc( 633edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 634edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 635edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 636edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.AUC), 637abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.AUC_PR): 638edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _auc( 639edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 640edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 641edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 642edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir curve='PR', 643edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.AUC_PR) 64434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower } 64534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower for threshold in self._thresholds: 64634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold 647abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 6480fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower accuracy_key)] = _accuracy_at_threshold( 6490fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 6500fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 6510fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 6520fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower threshold=threshold, 6530fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=accuracy_key) 65434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Precision for positive examples. 
65534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower precision_key = keys.PRECISION_AT_THRESHOLD % threshold 656abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 6570fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower precision_key)] = _precision_at_threshold( 6580fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 6590fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 6600fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 6610fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower threshold=threshold, 6620fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=precision_key) 66334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Recall for positive examples. 66434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower recall_key = keys.RECALL_AT_THRESHOLD % threshold 665abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 6660fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower recall_key)] = _recall_at_threshold( 6670fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 6680fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 6690fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 6700fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower threshold=threshold, 6710fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=recall_key) 67234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metric_ops 67334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 6748d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower def create_loss(self, features, mode, logits, labels): 6758d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower """See `Head`.""" 6768d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. 
Unique TensorFlower del mode, features # Unused for this head. 6778d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension) 6788d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower if self._label_vocabulary is not None: 6798d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = lookup_ops.index_table_from_tensor( 6808d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower vocabulary_list=tuple(self._label_vocabulary), 6818d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower name='class_id_lookup').lookup(labels) 6828d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = math_ops.to_float(labels) 6838d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = _assert_range(labels, 2) 6848d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower return LossAndLabels( 6858d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower unweighted_loss=nn.sigmoid_cross_entropy_with_logits( 6868d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels=labels, logits=logits), 6878d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower processed_labels=labels) 6888d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower 68934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def create_estimator_spec( 69034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower self, features, mode, logits, labels=None, train_op_fn=None): 69134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """See `Head`.""" 692169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir # Predict. 
693169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope('head'): 694169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(None, 'predictions', (logits,)): 695169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys = prediction_keys.PredictionKeys 696169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir logits = _check_logits(logits, self.logits_dimension) 697169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC) 698169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir two_class_logits = array_ops.concat( 699169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir (array_ops.zeros_like(logits), logits), 1, name='two_class_logits') 700169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES) 701169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir class_ids = array_ops.reshape( 702169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes') 703169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir if self._label_vocabulary: 704169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir table = lookup_ops.index_to_string_table_from_tensor( 705169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir vocabulary_list=self._label_vocabulary, 706169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir name='class_string_lookup') 707169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir classes = table.lookup(class_ids) 708169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir else: 709169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir classes = string_ops.as_string(class_ids, name='str_classes') 710169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir predictions = { 711169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.LOGITS: logits, 712169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.LOGISTIC: logistic, 
713169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.PROBABILITIES: scores, 714169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.CLASS_IDS: class_ids, 715169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.CLASSES: classes, 716169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir } 71734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.PREDICT: 71802ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir batch_size = array_ops.shape(logistic)[0] 71902ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir export_class_list = self._label_vocabulary 72002ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir if not export_class_list: 72102ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir export_class_list = string_ops.as_string([0, 1]) 72202ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir export_output_classes = array_ops.tile( 72302ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir input=array_ops.expand_dims(input=export_class_list, axis=0), 72402ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir multiples=[batch_size, 1]) 72502ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir classifier_output = export_output.ClassificationOutput( 72602ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir scores=scores, 72702ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir # `ClassificationOutput` requires string classes. 72802ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir classes=export_output_classes) 72934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 73034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.PREDICT, 73134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower predictions=predictions, 732edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir export_outputs={ 7334db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _DEFAULT_SERVING_KEY: classifier_output, 7344db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _CLASSIFY_SERVING_KEY: classifier_output, 7354db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _REGRESS_SERVING_KEY: export_output.RegressionOutput( 7364db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel value=logistic), 7374db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions) 738edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir }) 73934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 74034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Eval. 7418d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower unweighted_loss, processed_labels = self.create_loss( 7428d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower features=features, mode=mode, logits=logits, labels=labels) 743d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weights = _weights(features, self._weight_column) 74434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower training_loss = losses.compute_weighted_loss( 74534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) 74634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.EVAL: 74734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 74834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.EVAL, 74934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 75034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower loss=training_loss, 75134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower eval_metric_ops=self._eval_metric_ops( 7528d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels=processed_labels, 75334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower logits=logits, 75434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower logistic=logistic, 75534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower scores=scores, 756edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir class_ids=class_ids, 75734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower unweighted_loss=unweighted_loss, 75834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights=weights)) 75934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 76034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Train. 76134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if train_op_fn is None: 76234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower raise ValueError('train_op_fn can not be None.') 763169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(''): 7640fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower summary.scalar( 765abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, metric_keys.MetricKeys.LOSS), 7660fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower training_loss) 7670fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower summary.scalar( 768abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN), 7690fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower losses.compute_weighted_loss( 7700fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower unweighted_loss, weights=weights, 7710fab1a5d397d3d44eb4df84b1a81b30677873e8fA. 
def _regression_head_with_mean_squared_error_loss(weight_column=None,
                                                  label_dimension=1,
                                                  name=None):
  """Builds a regression `_Head` that is trained with mean squared error.

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    name: Name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `_Head` for linear regression.
  """
  head = _RegressionHeadWithMeanSquaredErrorLoss(
      label_dimension=label_dimension,
      weight_column=weight_column,
      name=name)
  return head
class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
  """`Head` for regression using the mean squared loss."""

  def __init__(self, label_dimension, weight_column=None, name=None):
    """`Head` for regression.

    Args:
      label_dimension: Number of regression labels per example; also used as
        the logits dimension of this head.
      weight_column: A string or `_NumericColumn` naming the feature that holds
        the example weights, or `None` for no weighting.
      name: Name of the head. When set, summary and metric keys are built with
        `_summary_key(name, ...)`.

    Raises:
      ValueError: If `label_dimension` is smaller than 1.
    """
    if label_dimension < 1:
      raise ValueError('Invalid label_dimension %s.' % label_dimension)
    self._logits_dimension = label_dimension
    self._weight_column = weight_column
    self._name = name

  @property
  def name(self):
    """Name of the head as given to the constructor (may be `None`)."""
    return self._name

  @property
  def logits_dimension(self):
    """Size of the logits' last dimension, equal to `label_dimension`."""
    return self._logits_dimension

  def create_loss(self, features, mode, logits, labels):
    """See `Head`.

    Casts the labels to float, checks them against the logits dimension and
    computes the unreduced mean squared error against `logits`.

    Args:
      features: Unused by this head.
      mode: Unused by this head.
      logits: Predictions `Tensor`.
      labels: Labels `Tensor`.

    Returns:
      A `LossAndLabels` holding the unweighted loss and the processed labels.
    """
    del mode, features  # Unused for this head.
    labels = _check_labels(
        _maybe_expand_dim(math_ops.to_float(labels)), self._logits_dimension)
    return LossAndLabels(
        unweighted_loss=losses.mean_squared_error(
            labels=labels, predictions=logits, reduction=losses.Reduction.NONE),
        processed_labels=labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Args:
      features: Input features; used to fetch the example weights.
      mode: One of the `ModeKeys` values.
      logits: Logits `Tensor`, checked against the logits dimension and
        exported unchanged as the predictions.
      labels: Labels `Tensor`; forwarded to `create_loss` outside PREDICT mode.
      train_op_fn: Callable taking the training loss and returning a train op.
        Required when `mode` is neither PREDICT nor EVAL.

    Returns:
      A `model_fn.EstimatorSpec` for the requested mode.

    Raises:
      ValueError: If a train spec is requested and `train_op_fn` is `None`.
    """
    # Predict.
    with ops.name_scope('head'):
      logits = _check_logits(logits, self._logits_dimension)
      predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits}
      if mode == model_fn.ModeKeys.PREDICT:
        regression_output = export_output.RegressionOutput(value=logits)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: regression_output,
                _REGRESS_SERVING_KEY: regression_output,
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      # Eval.
      unweighted_loss, _ = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      weights = _weights(features, self._weight_column)
      training_loss = losses.compute_weighted_loss(
          unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == model_fn.ModeKeys.EVAL:
        # Estimator already adds a metric for loss.
        # The key goes through `_summary_key` so that the eval metric is named
        # the same way as the train summaries written below.
        eval_metric_ops = {
            _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN):
                metrics_lib.mean(unweighted_loss, weights=weights)
        }
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=eval_metric_ops)

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # Summaries are created under the '' name scope rather than 'head'.
    with ops.name_scope(''):
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS),
          training_loss)
      summary.scalar(
          _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
          losses.compute_weighted_loss(
              unweighted_loss, weights=weights,
              reduction=losses.Reduction.MEAN))
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=training_loss,
        train_op=train_op_fn(training_loss))
def _assert_range(labels, n_classes):
  """Returns `labels` guarded by checks that all values are in [0, n_classes).

  Args:
    labels: Labels `Tensor` to validate.
    n_classes: Exclusive upper bound; converted to a tensor of the labels'
      dtype.

  Returns:
    An identity of `labels` with control dependencies on both assertions.
  """
  with ops.name_scope(None, 'assert_range', (labels,)):
    upper_bound = ops.convert_to_tensor(n_classes, dtype=labels.dtype)
    below_upper_bound = check_ops.assert_less(
        labels, upper_bound, message='Label IDs must < n_classes')
    non_negative = check_ops.assert_non_negative(
        labels, message='Label IDs must >= 0')
    # The identity op forces both checks to run before the labels are consumed.
    with ops.control_dependencies([below_upper_bound, non_negative]):
      checked_labels = array_ops.identity(labels)
  return checked_labels
def _weights(features, weight_column):
  """Fetches weights from features.

  Args:
    features: Mapping of feature tensors; its values feed the name scope and
      the weight column is resolved against it.
    weight_column: `None`, a feature key string, or a `_NumericColumn`.

  Returns:
    `1.` when `weight_column` is `None`; otherwise the weights cast to float
    and passed through `_maybe_expand_dim`.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the weights' dtype is neither floating point nor integer.
  """
  with ops.name_scope(None, 'weights', values=features.values()):
    # This early exit deliberately happens after the scope has been entered.
    if weight_column is None:
      return 1.

    column = weight_column
    if isinstance(column, six.string_types):
      column = feature_column_lib.numeric_column(key=column)
    numeric_column_type = feature_column_lib._NumericColumn  # pylint: disable=protected-access
    if not isinstance(column, numeric_column_type):
      raise TypeError(
          'Weight column must be either a string or _NumericColumn. '
          'Given type: {}.'.format(type(column)))

    builder = feature_column_lib._LazyBuilder(features)  # pylint: disable=protected-access
    raw_weights = column._get_dense_tensor(builder)  # pylint: disable=protected-access
    weights_dtype = raw_weights.dtype
    if not weights_dtype.is_floating and not weights_dtype.is_integer:
      raise ValueError(
          'Weight column should be castable to float. Given dtype: {}'.format(
              weights_dtype))
    return _maybe_expand_dim(math_ops.to_float(raw_weights, name='weights'))