134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# you may not use this file except in compliance with the License. 534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# You may obtain a copy of the License at 634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# 934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 1034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 1134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# See the License for the specific language governing permissions and 1334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# limitations under the License. 1434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower# ============================================================================== 1534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower"""Abstractions for the head(s) of a model.""" 1634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 1734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import absolute_import 1834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom __future__ import division 1934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlowerfrom __future__ import print_function 2034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport abc 22e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport collections 2334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerimport six 2534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 2634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator import model_fn 273942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerfrom tensorflow.python.estimator import util 2834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import metric_keys 2934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.canned import prediction_keys 3034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.estimator.export import export_output 31d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispirfrom tensorflow.python.feature_column import feature_column as feature_column_lib 3234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import dtypes 3334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import ops 3434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.framework import sparse_tensor 3534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops 3634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import check_ops 3767c2ab669448828dc722af651917aa9abd01abf7A. 
Unique TensorFlowerfrom tensorflow.python.ops import control_flow_ops 38ce21f5e372360451528c9716a743b65308421423Mustafa Ispirfrom tensorflow.python.ops import lookup_ops 3934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops 4034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import metrics as metrics_lib 4134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import nn 4234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import string_ops 4334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops import weights_broadcast_ops 4434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerfrom tensorflow.python.ops.losses import losses 4502ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispirfrom tensorflow.python.saved_model import signature_constants 46169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispirfrom tensorflow.python.summary import summary 4702ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir 4802ebc44be9a660a54793973110ae26cf948ffceaMustafa Ispir_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 4934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 504db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel# The above default is defined by TF Serving, but these next three are just 514db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel# a local convention without any special meaning. 524db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel_CLASSIFY_SERVING_KEY = 'classification' 534db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel_REGRESS_SERVING_KEY = 'regression' 544db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel_PREDICT_SERVING_KEY = 'predict' 554db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel 5634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 5799d51b8da87ba462f66a6be90212677f2cae9e32A. 
# A LossSpec contains
# * a scalar `Tensor` representing reduced weighted training loss
# * a `Tensor` representing the unreduced unweighted loss
# * a `Tensor` representing the example weights
# * possibly processed labels (e.g. vocabulary lookup, shape manipulation, etc)
LossSpec = collections.namedtuple(
    'LossSpec', ['training_loss', 'unreduced_loss', 'weights',
                 'processed_labels'])


def _summary_key(head_name, val):
  """Returns `'<val>/<head_name>'` if `head_name` is truthy, else `val`."""
  return '%s/%s' % (val, head_name) if head_name else val


class _Head(object):
  """Interface for the head/top of a model.

  Given logits (or output of a hidden layer), a Head knows how to compute
  predictions, loss, train_op, metrics and export outputs. It is meant to:

  1. Simplify writing model_fn and to make model_fn more configurable
  2. Support wide range of machine learning models. Since most heads can work
     with logits, they can support DNN, RNN, Wide, Wide&Deep,
     Global objectives, Gradient boosted trees and many other types
     of machine learning models.

  Common usage:
  Here is a simplified model_fn to build a DNN regression model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally your callers can pass head to model_fn as a param.
      head = tf.contrib.learn.regression_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_estimator_spec(
          features=features,
          labels=labels,
          mode=mode,
          logits=logits,
          train_op_fn=_train_op_fn)
    ```

  There are cases where computing and applying gradients can not be meaningfully
  captured with train_op_fn we support (for example, with sync optimizer). In
  such case, you can take the responsibility on your own. Here is a common
  use case,
    ```python
    estimator_spec = head.create_estimator_spec(
        features=features,
        labels=labels,
        mode=mode,
        logits=logits,
        train_op_fn=tf.contrib.learn.no_op_train_fn)
    if mode == model_fn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=estimator_spec.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      # ... update train_op and hooks in EstimatorSpec and return
    ```
  """
  # NOTE(review): `__metaclass__` is only honored by Python 2. Under Python 3
  # it is a plain class attribute, so the abstract members below are not
  # enforced there. `six.add_metaclass(abc.ABCMeta)` would cover both
  # interpreters -- confirm every subclass implements all abstract members
  # before switching.
  __metaclass__ = abc.ABCMeta

  @abc.abstractproperty
  def name(self):
    """The name of this head.

    Returns:
      A string.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_loss(self, features, mode, logits, labels):
    """Returns a loss Tensor from provided logits.

    This function is designed to be used by framework developers.  Almost all
    users should use create_estimator_spec(), which calls this internally.
    `mode` and `features` are most likely not used, but some Head
    implementations may require them.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used for loss construction.
      labels: Labels `Tensor`, or `dict` of same.

    Returns:
      A LossSpec that contains
      * the scalar `Tensor` representing reduced weighted training loss
      * the `Tensor` representing the unreduced unweighted loss
      * the `Tensor` representing the example weights
      * possibly processed labels (e.g. vocabulary lookup, shape manipulation,
        etc.)

      To be extendable in the future.
    """
    raise NotImplementedError('Calling an abstract method.')

  @abc.abstractmethod
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None,
      regularization_losses=None):
    """Returns `EstimatorSpec` that a model_fn can return.

    Please note that,
    + All args must be passed via name.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` to be used by the head.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
        to optimize the model with the loss. This is used in TRAIN mode and
        must not be None. None is allowed in other modes. If you want to
        optimize loss yourself you can pass `no_op_train_fn` and then use
        EstimatorSpec.loss to compute and apply gradients.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses.

    Returns:
      `EstimatorSpec`.
    """
    raise NotImplementedError('Calling an abstract method.')


def _check_dense_labels_match_logits_and_reshape(
    labels, logits, expected_labels_dimension):
  """Checks that labels shape matches logits and reshapes if needed.

  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Then labels
  shape must be [D0, D1, ... DN, expected_labels_dimension].
  If expected_labels_dimension=1, labels could be [D0, D1, ... DN] and this
  method reshapes them to [D0, D1, ... DN, 1].

  Args:
    labels: labels Tensor.
    logits: logits Tensor.
    expected_labels_dimension: Integer.
  Returns:
    Validated and reshaped labels Tensor.
  Raises:
    ValueError: If labels is None.
    ValueError: If labels is a SparseTensor.
    ValueError: If labels shape is statically defined and fails validation.
    OpError: If labels shape is not statically defined and fails validation.
  """
  if labels is None:
    raise ValueError(
        'You must provide a labels Tensor. Given: None. '
        'Suggested troubleshooting steps: Check that your data contain '
        'your label feature. Check that your input_fn properly parses and '
        'returns labels.')
  with ops.name_scope(None, 'labels', (labels, logits)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError(
          'SparseTensor labels are not supported. '
          'labels must be a Tensor of shape [D0, D1, ..., DN, %s], '
          'e.g. [batch_size, %s]. '
          'Suggested Fix (1): Check the label feature in your data. '
          'Each example must contain %s value(s). If not, your choice of label '
          'was probably incorrect. '
          'Suggested Fix (2): In your input_fn, use '
          'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'
          '' % (expected_labels_dimension, expected_labels_dimension,
                expected_labels_dimension))
    # When both static ranks are known and labels has exactly one fewer
    # dimension than logits, append a trailing dimension of size 1.
    if (labels.shape.ndims is not None and logits.shape.ndims is not None and
        labels.shape.ndims == logits.shape.ndims - 1):
      labels = array_ops.expand_dims(labels, -1)
    labels_shape = array_ops.shape(labels)
    logits_shape = array_ops.shape(logits)
    err_msg = (
        'labels shape must be [D0, D1, ... DN, {}]. '
        'Suggested Fix: check your n_classes argument to the estimator '
        'and/or the shape of your label.'.format(expected_labels_dimension))
    assert_rank = check_ops.assert_rank_at_least(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      # Static check: raise immediately when the last dimension is already
      # known and does not match. Otherwise the assert op below performs the
      # comparison on the dynamic shapes.
      static_shape = labels.shape
      if static_shape.ndims is not None:
        dim1 = static_shape[-1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s. Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      # Expected shape is the logits shape with its last dimension replaced by
      # `expected_labels_dimension`.
      expected_labels_shape = array_ops.concat(
          [logits_shape[:-1], [expected_labels_dimension]], axis=0)
      assert_dimension = check_ops.assert_equal(
          expected_labels_shape, labels_shape, message=err_msg,
          data=['expected_labels_shape: ', expected_labels_shape,
                'labels_shape: ', labels_shape])
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)


def _get_weights_and_check_match_logits(
    features, weight_column, logits, allow_per_logit_weights=False):
  """Fetches weights from features and checks that the shape matches logits.

  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
  can be either:
  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
  * [D0, D1, ... DN, 1]
  * [D0, D1, ... DN]: In this case, weights is reshaped into
    [D0, D1, ... DN, 1] to work with weight broadcasting rules.

  Args:
    features: The features dict that contains weights.
    weight_column: The weight column. If not given, this method returns 1.
    logits: logits Tensor.
    allow_per_logit_weights: Boolean. Whether we allow weights along the logits
      dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.
  Returns:
    Validated and reshaped weights Tensor.
  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the weights `Tensor` cannot be cast into float.
  """
  if allow_per_logit_weights:
    err_msg = (
        'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
        '[D0, D1, ... DN, logits_dimension]')
  else:
    err_msg = (
        'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
  with ops.name_scope(
      None, 'weights',
      values=tuple(six.itervalues(features)) + (logits,)) as scope:
    # Fetch the weights.
    if weight_column is None:
      return 1.
    # A string is treated as the key of a numeric feature of shape (1,).
    if isinstance(weight_column, six.string_types):
      weight_column = feature_column_lib.numeric_column(
          key=weight_column, shape=(1,))
    if not isinstance(weight_column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
      raise TypeError('Weight column must be either a string or _NumericColumn.'
                      ' Given type: {}.'.format(type(weight_column)))
    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access
        feature_column_lib._LazyBuilder(features))  # pylint: disable=protected-access
    if not (weights.dtype.is_floating or weights.dtype.is_integer):
      raise ValueError('Weight column should be castable to float. '
                       'Given dtype: {}'.format(weights.dtype))
    weights = math_ops.to_float(weights, name='weights')

    # Validate the weights shape.
    weights_shape = array_ops.shape(weights, name='weights_shape')
    logits_shape = array_ops.shape(logits, name='logits_shape')
    # Case [D0, D1, ... DN]: static ranks are known and weights has one fewer
    # dimension than logits. Check the leading dimensions, then append a
    # trailing dimension of size 1.
    if (weights.shape.ndims is not None and logits.shape.ndims is not None and
        weights.shape.ndims == logits.shape.ndims - 1):
      assert_dimension = check_ops.assert_equal(
          logits_shape[:-1], weights_shape, message=err_msg,
          data=['logits_shape: ', logits_shape,
                'weights_shape: ', weights_shape])
      with ops.control_dependencies([assert_dimension]):
        return array_ops.expand_dims(weights, -1, name=scope)
    # Shape [D0, D1, ... DN, 1], built from the dynamic logits shape.
    supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
    if allow_per_logit_weights:
      # Accept either the full logits shape or the [..., 1] shape.
      condition = math_ops.reduce_any(
          [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
           math_ops.reduce_all(math_ops.equal(
               supported_weights_shape, weights_shape))])
      assert_dimension = control_flow_ops.Assert(
          condition=condition,
          data=[err_msg, 'logits_shape: ', logits_shape,
                'weights_shape: ', weights_shape])
    else:
      assert_dimension = check_ops.assert_equal(
          supported_weights_shape, weights_shape, message=err_msg,
          data=['logits_shape: ', logits_shape,
                'weights_shape: ', weights_shape])
    with ops.control_dependencies([assert_dimension]):
      return array_ops.identity(weights, name=scope)


def _check_logits_final_dim(logits, expected_logits_dimension):
  """Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
  with ops.name_scope(None, 'logits', (logits,)) as scope:
    logits = math_ops.to_float(logits)
    logits_shape = array_ops.shape(logits)
    assert_rank = check_ops.assert_rank_at_least(
        logits, 2, data=[logits_shape],
        message='logits shape must be [D0, D1, ... DN, logits_dimension]')
Unique TensorFlower with ops.control_dependencies([assert_rank]): 36167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower static_shape = logits.shape 36267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower if static_shape.ndims is not None and static_shape[-1] is not None: 36367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower if static_shape[-1] != expected_logits_dimension: 36467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower raise ValueError( 36567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 'logits shape must be [D0, D1, ... DN, logits_dimension], ' 36667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 'got %s.' % (static_shape,)) 36767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower return logits 36867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower assert_dimension = check_ops.assert_equal( 36967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower expected_logits_dimension, logits_shape[-1], data=[logits_shape], 37067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower message='logits shape must be [D0, D1, ... DN, logits_dimension]') 37167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower with ops.control_dependencies([assert_dimension]): 37267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower return array_ops.identity(logits, name=scope) 37367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 37467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 3753942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerdef _validate_loss_fn_args(loss_fn): 3763942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower """Validates loss_fn arguments. 3773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 3783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Required arguments: labels, logits. 3793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Optional arguments: features. 
3803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 3813942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Args: 3823942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn: The loss function. 3833942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Raises: 3843942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower ValueError: If the signature is unexpected. 3853942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower """ 3863942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn_args = util.fn_args(loss_fn) 3873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower for required_arg in ['labels', 'logits']: 3883942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if required_arg not in loss_fn_args: 3893942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower raise ValueError( 3903942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 'loss_fn must contain argument: {}. ' 3913942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 'Given arguments: {}'.format(required_arg, loss_fn_args)) 3923942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) 3933942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if invalid_args: 3943942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) 3953942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 3963942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 3973942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlowerdef _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1): 3983942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower """Calls loss_fn and checks the returned shape. 3993942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 4003942de820958e75792c86a50084f9312b5edd3baA. 
Unique TensorFlower Args: 4013942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn: The loss function. 4023942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower labels: Processed labels Tensor. 4033942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension]. 4043942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower features: Features dict. 4053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower expected_loss_dim: The expected last dimension of loss Tensor. 4063942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Returns: 4073942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim]. 4083942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower """ 4093942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn_args = util.fn_args(loss_fn) 4103942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower kwargs = {} 4113942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if 'features' in loss_fn_args: 4123942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower kwargs['features'] = features 4133942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower with ops.name_scope( 4143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower None, 'call_loss_fn', 4153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower values=[labels, logits] + list(six.itervalues(features))): 4163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) 4173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower logits_shape = array_ops.shape(logits, name='logits_shape') 4183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower expected_loss_shape = array_ops.concat( 4193942de820958e75792c86a50084f9312b5edd3baA. 
Unique TensorFlower [logits_shape[:-1], [expected_loss_dim]], axis=0, 4203942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower name='expected_loss_shape') 4213942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_shape = array_ops.shape(unweighted_loss, name='loss_shape') 4223942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower check_loss_shape_op = control_flow_ops.Assert( 4233942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower math_ops.reduce_all(math_ops.equal(loss_shape, expected_loss_shape)), 4243942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower data=[ 4253942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 'loss_fn must return Tensor of shape ' 4263942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower '[D0, D1, ... DN, {}]. '.format(expected_loss_dim), 4273942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape], 4283942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower name='check_loss_shape') 4293942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower with ops.control_dependencies([check_loss_shape_op]): 4303942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower return array_ops.identity(unweighted_loss) 4313942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 4323942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 43334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _indicator_labels_mean(labels, weights=None, name=None): 43434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope: 43534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels = math_ops.to_float(labels, name='labels') 43634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if weights is not None: 43734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower weights = weights_broadcast_ops.broadcast_weights(weights, labels) 43834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metrics_lib.mean(labels, weights=weights, name=scope) 43934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 44034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 441df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlowerdef _classification_output(scores, n_classes, label_vocabulary=None): 442df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower batch_size = array_ops.shape(scores)[0] 443df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower if label_vocabulary: 444df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower export_class_list = label_vocabulary 445df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower else: 446df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower export_class_list = string_ops.as_string(math_ops.range(n_classes)) 447df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower export_output_classes = array_ops.tile( 448df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower input=array_ops.expand_dims(input=export_class_list, axis=0), 449df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower multiples=[batch_size, 1]) 450df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower return export_output.ClassificationOutput( 451df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower scores=scores, 452df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower # `ClassificationOutput` requires string classes. 453df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower classes=export_output_classes) 454df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower 455df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower 45634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _accuracy_baseline(labels_mean): 45734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower """Return accuracy baseline based on labels mean. 45834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 45934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower This is the best the model could do by always predicting one class. 46034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 46134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Args: 46234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels_mean: Tuple of value and update op. 46334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 46434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Returns: 46534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Tuple of value and update op. 46634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """ 46734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope(None, 'accuracy_baseline', labels_mean): 46834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower value, update_op = labels_mean 46934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return ( 47034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower math_ops.maximum(value, 1. - value, name='value'), 47134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower math_ops.maximum(update_op, 1 - update_op, name='update_op')) 47234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 47334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 47434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _predictions_mean(predictions, weights=None, name=None): 47534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope( 47634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower name, 'predictions_mean', (predictions, weights)) as scope: 47734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower predictions = math_ops.to_float(predictions, name='predictions') 47834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if weights is not None: 47934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 48034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metrics_lib.mean(predictions, weights=weights, name=scope) 48134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 48234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 48334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _auc(labels, predictions, weights=None, curve='ROC', name=None): 48434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope(name, 'auc', (predictions, labels, weights)) as scope: 48534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions = math_ops.to_float(predictions, name='predictions') 48634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if weights is not None: 48734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 48834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metrics_lib.auc( 48934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, predictions=predictions, weights=weights, curve=curve, 49034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower name=scope) 49134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 49234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 49334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _accuracy_at_threshold(labels, predictions, weights, threshold, name=None): 49434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope( 49534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower name, 'accuracy_at_%s' % threshold, 49634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower (predictions, labels, weights, threshold)) as scope: 49734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower threshold_predictions = math_ops.to_float( 49834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower math_ops.greater_equal(predictions, threshold)) 49934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metrics_lib.accuracy( 50034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, predictions=threshold_predictions, weights=weights, 50134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower name=scope) 50234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 50334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 50434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _precision_at_threshold(labels, predictions, weights, threshold, name=None): 50534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope( 50634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower name, 'precision_at_%s' % threshold, 50734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower (predictions, labels, weights, threshold)) as scope: 50834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower precision_tensor, update_op = metrics_lib.precision_at_thresholds( 50934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, predictions=predictions, thresholds=(threshold,), 51034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights=weights, name=scope) 51134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 51234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 51334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 51434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlowerdef _recall_at_threshold(labels, predictions, weights, threshold, name=None): 51534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope( 51634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower name, 'recall_at_%s' % threshold, 51734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower (predictions, labels, weights, threshold)) as scope: 51834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower precision_tensor, update_op = metrics_lib.recall_at_thresholds( 51934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, predictions=predictions, thresholds=(threshold,), 52034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights=weights, name=scope) 52134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 52234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 52334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 52420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlowerdef _multi_class_head_with_softmax_cross_entropy_loss( 52520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower n_classes, 52620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weight_column=None, 52720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower label_vocabulary=None, 52820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 5293942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 53020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower name=None): 53134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """Creates a '_Head' for multi class classification. 53234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 5337948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. 
5347948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower In many applications, the shape is `[batch_size, n_classes]`. 5357948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower 5367948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `labels` must be a dense `Tensor` with shape matching `logits`, namely 5377948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string 5387948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, 5397948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `labels` must be an integer `Tensor` with values specifying the class index. 5407948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower 5417948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower If `weight_column` is specified, weights must be of shape 5427948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. 5437948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower 5447948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower The loss is the weighted sum over the input dimensions. Namely, if the input 5457948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower labels have shape `[batch_size, 1]`, the loss is the weighted sum over 5467948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `batch_size`. 54734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 5483942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or 5493942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower `(labels, logits, features)` as arguments and returns unreduced loss with 5503942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower shape `[D0, D1, ... DN, 1]`. 
`loss_fn` must support integer `labels` with 5513942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to 5523942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower the input labels before passing them to `loss_fn`. 5533942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 55434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Args: 55534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower n_classes: Number of classes, must be greater than 2 (for 2 classes, use 55634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`). 557d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column: A string or a `_NumericColumn` created by 558d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir `tf.feature_column.numeric_column` defining feature column representing 55934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights. It is used to down weight or boost examples during training. It 56034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower will be multiplied by the loss of the example. 5619c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol label_vocabulary: A list or tuple of strings representing possible label 5629c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol values. If it is not given, that means labels are already encoded as an 5639c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol integer within [0, n_classes). If given, labels must be of string type and 5649c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol have any value in `label_vocabulary`. Note that errors will be raised if 5659c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol `label_vocabulary` is not provided but labels are strings. 56620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction: One of `tf.losses.Reduction` except `NONE`. 
Describes how to 56720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower reduce training loss over batch. Defaults to `SUM`. 5683942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn: Optional loss function. 569abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name: name of the head. If provided, summary and metrics keys will be 57001c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower suffixed by `"/" + name`. Also used as `name_scope` when creating ops. 57134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 57234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Returns: 5730815de21239955e346b562e899640649c8d2b9cbBenoit Steiner An instance of `_Head` for multi class classification. 57434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 57534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Raises: 57620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower ValueError: If `n_classes`, `label_vocabulary` or `loss_reduction` is 57720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower invalid. 57834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """ 579edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir if label_vocabulary is not None and not isinstance(label_vocabulary, 580edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir (list, tuple)): 5819c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol raise ValueError( 5829c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol 'label_vocabulary should be a list or a tuple. Given type: {}'.format( 5839c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol type(label_vocabulary))) 58420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if (loss_reduction not in losses.Reduction.all() or 58520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction == losses.Reduction.NONE): 58620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) 5873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if loss_fn: 5883942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower _validate_loss_fn_args(loss_fn) 58920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower return _MultiClassHeadWithSoftmaxCrossEntropyLoss( 59020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower n_classes=n_classes, 59120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weight_column=weight_column, 59220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower label_vocabulary=label_vocabulary, 59320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=loss_reduction, 5943942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=loss_fn, 59520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower name=name) 59634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 59734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 59834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): 59934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """See `_multi_class_head_with_softmax_cross_entropy_loss`.""" 60034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 6010fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower def __init__(self, 6020fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower n_classes, 6030fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weight_column=None, 6040fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower label_vocabulary=None, 60520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 6063942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 607abf4aa037f21231594c44fc08fb50607de0288b7A. 
Unique TensorFlower name=None): 60834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if (n_classes is None) or (n_classes <= 2): 60934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower raise ValueError('n_classes must be > 2: %s.' % n_classes) 61034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower self._n_classes = n_classes 611d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir self._weight_column = weight_column 612ce21f5e372360451528c9716a743b65308421423Mustafa Ispir self._label_vocabulary = label_vocabulary 61320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower self._loss_reduction = loss_reduction 6143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower self._loss_fn = loss_fn 615abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower self._name = name 616abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower 617abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower @property 618abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower def name(self): 619abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower return self._name 62034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 62134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower @property 62234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def logits_dimension(self): 62334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return self._n_classes 62434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 625c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower def _eval_metric_ops( 626c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower self, labels, class_ids, weights, unreduced_loss, regularization_loss): 62734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """Returns the Eval metric ops.""" 62834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope( 629c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. 
Unique TensorFlower None, 'metrics', 630c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower (labels, class_ids, weights, unreduced_loss, regularization_loss)): 63134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower keys = metric_keys.MetricKeys 63234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower metric_ops = { 63334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Estimator already adds a metric for loss. 63434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # TODO(xiejw): Any other metrics? 635abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN): 6360fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower metrics_lib.mean( 63720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower values=unreduced_loss, 63820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 63999d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower name=keys.LOSS_MEAN), 640abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.ACCURACY): 6410fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower metrics_lib.accuracy( 6420fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 6430fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=class_ids, 6440fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 6450fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=keys.ACCURACY), 64634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower } 647c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower if regularization_loss is not None: 648c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = ( 649c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower metrics_lib.mean( 650c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. 
Unique TensorFlower values=regularization_loss, 651c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower name=keys.LOSS_REGULARIZATION)) 65234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metric_ops 65334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 654ce21f5e372360451528c9716a743b65308421423Mustafa Ispir def _label_ids(self, labels): 655ce21f5e372360451528c9716a743b65308421423Mustafa Ispir """Converts labels to integer id space.""" 656ce21f5e372360451528c9716a743b65308421423Mustafa Ispir if self._label_vocabulary is None: 657ce21f5e372360451528c9716a743b65308421423Mustafa Ispir if not labels.dtype.is_integer: 6589c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol raise ValueError('Labels dtype should be integer. Instead got {}.'. 6599c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol format(labels.dtype)) 660ce21f5e372360451528c9716a743b65308421423Mustafa Ispir label_ids = labels 661ce21f5e372360451528c9716a743b65308421423Mustafa Ispir else: 662ce21f5e372360451528c9716a743b65308421423Mustafa Ispir if labels.dtype != dtypes.string: 663ce21f5e372360451528c9716a743b65308421423Mustafa Ispir raise ValueError('Labels dtype should be string if there is a ' 664ce21f5e372360451528c9716a743b65308421423Mustafa Ispir 'vocabulary. Instead got {}'.format(labels.dtype)) 665ce21f5e372360451528c9716a743b65308421423Mustafa Ispir label_ids = lookup_ops.index_table_from_tensor( 666ce21f5e372360451528c9716a743b65308421423Mustafa Ispir vocabulary_list=tuple(self._label_vocabulary), 667ce21f5e372360451528c9716a743b65308421423Mustafa Ispir name='class_id_lookup').lookup(labels) 668edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir return _assert_range(label_ids, self._n_classes) 669ce21f5e372360451528c9716a743b65308421423Mustafa Ispir 6708d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower def create_loss(self, features, mode, logits, labels): 6718d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. 
Unique TensorFlower """See `Head`.""" 67299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower del mode # Unused for this head. 6737948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower logits = ops.convert_to_tensor(logits) 6747948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower labels = _check_dense_labels_match_logits_and_reshape( 6757948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower labels=labels, logits=logits, expected_labels_dimension=1) 6767948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower label_ids = self._label_ids(labels) 6773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if self._loss_fn: 6783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = _call_loss_fn( 6793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=self._loss_fn, labels=label_ids, logits=logits, 6803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower features=features, expected_loss_dim=1) 6813942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower else: 6823942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = losses.sparse_softmax_cross_entropy( 6833942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) 6843942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower # Restore the squeezed dim, so unweighted_loss matches the weights shape. 6853942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1) 686d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower weights = _get_weights_and_check_match_logits( 687d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower features=features, weight_column=self._weight_column, logits=logits) 68820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower training_loss = losses.compute_weighted_loss( 68920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower unweighted_loss, weights=weights, reduction=self._loss_reduction) 69099d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower return LossSpec( 69120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss=training_loss, 69220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower unreduced_loss=unweighted_loss, 69320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 6948d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower processed_labels=label_ids) 6958d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower 69634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def create_estimator_spec( 697c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower self, features, mode, logits, labels=None, train_op_fn=None, 698c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularization_losses=None): 6997948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower """Returns an `EstimatorSpec`. 7007948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower 7017948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower Args: 7027948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower features: Input `dict` of `Tensor` or `SparseTensor` objects. 7037948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower mode: Estimator's `ModeKeys`. 7047948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`. 7057948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower For many applications, the shape is `[batch_size, logits_dimension]`. 7067948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower labels: Labels integer or string `Tensor` with shape matching `logits`, 7073f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower namely `[D0, D1, ... 
DN, 1]` or `[D0, D1, ... DN]`. `labels` is 7083f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower required argument when `mode` equals `TRAIN` or `EVAL`. 7097948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower train_op_fn: Function that takes a scalar loss `Tensor` and returns 7107948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `train_op`. Required in TRAIN mode. 711c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularization_losses: A list of additional scalar losses to be added to 712c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower the training loss, such as regularization losses. These losses are 713c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower usually expressed as a batch average, so for best results users need to 7143f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower set `loss_reduction=SUM_OVER_BATCH_SIZE` or 7153f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to 716c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower avoid scaling errors. 7177948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower Returns: 7187948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower `EstimatorSpec`. 7197948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower Raises: 7207948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower ValueError: If `train_op_fn` is `None` in TRAIN mode. 7217948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower """ 72201c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower with ops.name_scope(self._name, 'head'): 7237948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower logits = _check_logits_final_dim(logits, self.logits_dimension) 72434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 72534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Predict. 72634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower pred_keys = prediction_keys.PredictionKeys 72734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower with ops.name_scope(None, 'predictions', (logits,)): 7287948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower # class_ids's shape is [D0, D1, ... DN]. 7297948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower class_ids = math_ops.argmax(logits, axis=-1, name=pred_keys.CLASS_IDS) 7307948641c2c9426a2d4b3baeee3ba27cb75680a9eA. Unique TensorFlower class_ids = array_ops.expand_dims(class_ids, axis=-1) 731ce21f5e372360451528c9716a743b65308421423Mustafa Ispir if self._label_vocabulary: 732ce21f5e372360451528c9716a743b65308421423Mustafa Ispir table = lookup_ops.index_to_string_table_from_tensor( 733ce21f5e372360451528c9716a743b65308421423Mustafa Ispir vocabulary_list=self._label_vocabulary, 734ce21f5e372360451528c9716a743b65308421423Mustafa Ispir name='class_string_lookup') 735ce21f5e372360451528c9716a743b65308421423Mustafa Ispir classes = table.lookup(class_ids) 736ce21f5e372360451528c9716a743b65308421423Mustafa Ispir else: 737ce21f5e372360451528c9716a743b65308421423Mustafa Ispir classes = string_ops.as_string(class_ids, name='str_classes') 738ce21f5e372360451528c9716a743b65308421423Mustafa Ispir 73934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower probabilities = nn.softmax(logits, name=pred_keys.PROBABILITIES) 74034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions = { 74134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower pred_keys.LOGITS: logits, 74234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower pred_keys.PROBABILITIES: probabilities, 74334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Expand to [batch_size, 1] 744ce21f5e372360451528c9716a743b65308421423Mustafa Ispir pred_keys.CLASS_IDS: class_ids, 745ce21f5e372360451528c9716a743b65308421423Mustafa Ispir pred_keys.CLASSES: classes, 74634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower } 74734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.PREDICT: 748df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower classifier_output = _classification_output( 749df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower scores=probabilities, n_classes=self._n_classes, 750df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower label_vocabulary=self._label_vocabulary) 75134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 75234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.PREDICT, 75334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 754ce21f5e372360451528c9716a743b65308421423Mustafa Ispir export_outputs={ 7554db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _DEFAULT_SERVING_KEY: classifier_output, 7564db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _CLASSIFY_SERVING_KEY: classifier_output, 7574db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions) 758ce21f5e372360451528c9716a743b65308421423Mustafa Ispir }) 75934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 76020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss, unreduced_loss, weights, label_ids = self.create_loss( 7618d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower features=features, mode=mode, logits=logits, labels=labels) 762c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower if regularization_losses: 763c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularization_loss = math_ops.add_n(regularization_losses) 764c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularized_training_loss = math_ops.add_n( 765c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. 
Unique TensorFlower [training_loss, regularization_loss]) 766c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower else: 767c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularization_loss = None 768c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularized_training_loss = training_loss 76999d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower # Eval. 77034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.EVAL: 77134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 77234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.EVAL, 77334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 774c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower loss=regularized_training_loss, 77534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower eval_metric_ops=self._eval_metric_ops( 776ce21f5e372360451528c9716a743b65308421423Mustafa Ispir labels=label_ids, 77734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower class_ids=class_ids, 77820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 779c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower unreduced_loss=unreduced_loss, 780c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularization_loss=regularization_loss)) 78134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 78234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Train. 78334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if train_op_fn is None: 7849c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol raise ValueError('train_op_fn cannot be None.') 78520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # Only summarize mean_loss for SUM reduction to preserve backwards 78620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # compatibility. 
Otherwise skip it to avoid unnecessary computation. 78720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if self._loss_reduction == losses.Reduction.SUM: 78820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower example_weight_sum = math_ops.reduce_sum( 78920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights * array_ops.ones_like(unreduced_loss)) 79020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = training_loss / example_weight_sum 79120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower else: 79220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = None 793169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(''): 794c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower keys = metric_keys.MetricKeys 7950fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower summary.scalar( 796c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower _summary_key(self._name, keys.LOSS), 797c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower regularized_training_loss) 79820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if mean_loss is not None: 79920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower summary.scalar( 800c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN), 80120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss) 802c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower if regularization_loss is not None: 803c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower summary.scalar( 804c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower _summary_key(self._name, keys.LOSS_REGULARIZATION), 805c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. 
Unique TensorFlower regularization_loss) 806169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir return model_fn.EstimatorSpec( 807169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir mode=model_fn.ModeKeys.TRAIN, 808169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir predictions=predictions, 809c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower loss=regularized_training_loss, 810c7e5b6878bc8867ca6828f8d844ff72a1c96aac1A. Unique TensorFlower train_op=train_op_fn(regularized_training_loss)) 81134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 81234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 81334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerdef _binary_logistic_head_with_sigmoid_cross_entropy_loss( 8143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower weight_column=None, 8153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower thresholds=None, 8163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower label_vocabulary=None, 8173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 8183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 8193942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower name=None): 820d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower """Creates a `_Head` for single label binary classification. 82134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 82234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower This head uses `sigmoid_cross_entropy_with_logits` loss. 82334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 824b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower The head expects `logits` with shape `[D0, D1, ... DN, 1]`. 825b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower In many applications, the shape is `[batch_size, 1]`. 826b992761e569ac505e156127b64b91eba2f0953cbA. 
Unique TensorFlower 827b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `labels` must be a dense `Tensor` with shape matching `logits`, namely 828b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string 829b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, 830b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `labels` must be float `Tensor` with values in the interval `[0, 1]`. 831b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower 832b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower If `weight_column` is specified, weights must be of shape 833b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. 834b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower 835b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower The loss is the weighted sum over the input dimensions. Namely, if the input 836b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower labels have shape `[batch_size, 1]`, the loss is the weighted sum over 837b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower `batch_size`. 83834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 8393942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or 8403942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower `(labels, logits, features)` as arguments and returns unreduced loss with 8413942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with 8423942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to 8433942de820958e75792c86a50084f9312b5edd3baA. 
Unique TensorFlower the input labels before passing them to `loss_fn`. 8443942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 84534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Args: 846d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column: A string or a `_NumericColumn` created by 847d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir `tf.feature_column.numeric_column` defining feature column representing 84834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights. It is used to down weight or boost examples during training. It 84934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower will be multiplied by the loss of the example. 85034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower thresholds: Iterable of floats in the range `(0, 1)`. For binary 85134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower classification metrics such as precision and recall, an eval metric is 85234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower generated for each threshold value. This threshold is applied to the 85334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower logistic values to determine the binary classification (i.e., above the 85434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower threshold is `true`, below is `false`. 8559c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol label_vocabulary: A list or tuple of strings representing possible label 8569c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol values. If it is not given, that means labels are already encoded within 8579c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol [0, 1]. If given, labels must be string type and have any value in 8589c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol `label_vocabulary`. Note that errors will be raised if `label_vocabulary` 8599c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol is not provided but labels are strings. 86020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to 86120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower reduce training loss over batch. Defaults to `SUM`. 8623942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn: Optional loss function. 863abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name: name of the head. If provided, summary and metrics keys will be 86401c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower suffixed by `"/" + name`. Also used as `name_scope` when creating ops. 86534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 86634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Returns: 867d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower An instance of `_Head` for binary classification. 86834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 86934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Raises: 87020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower ValueError: If `thresholds` contains a value outside of `(0, 1)`. 87120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower ValueError: If `loss_reduction` is invalid. 87234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """ 87379099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower thresholds = tuple(thresholds) if thresholds else tuple() 874edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir if label_vocabulary is not None and not isinstance(label_vocabulary, 875edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir (list, tuple)): 8769c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol raise ValueError( 8779c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol 'label_vocabulary should be a list or tuple. Given type: {}'.format( 8789c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol type(label_vocabulary))) 879edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir 88034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower for threshold in thresholds: 88134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if (threshold <= 0.0) or (threshold >= 1.0): 8829c40507f80434058f600ebebb8b9d6971dd0bdb4Petros Mol raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,))) 88320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if (loss_reduction not in losses.Reduction.all() or 88420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction == losses.Reduction.NONE): 88520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) 8863942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if loss_fn: 8873942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower _validate_loss_fn_args(loss_fn) 88834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss( 889d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column=weight_column, 890edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir thresholds=thresholds, 8910fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower label_vocabulary=label_vocabulary, 89220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=loss_reduction, 8933942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=loss_fn, 894abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name=name) 89534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 89634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 89734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): 89834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`.""" 89934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 9000fab1a5d397d3d44eb4df84b1a81b30677873e8fA. 
Unique TensorFlower def __init__(self, 9010fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weight_column=None, 9020fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower thresholds=None, 9030fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower label_vocabulary=None, 90420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 9053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 906abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name=None): 907d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir self._weight_column = weight_column 90879099d67761b3e56d1c3764cd34f97401571a211A. Unique TensorFlower self._thresholds = thresholds 909edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir self._label_vocabulary = label_vocabulary 91020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower self._loss_reduction = loss_reduction 9113942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower self._loss_fn = loss_fn 912abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower self._name = name 913abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower 914abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower @property 915abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower def name(self): 916abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower return self._name 91734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 91834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower @property 91934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def logits_dimension(self): 92034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return 1 92134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 92299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower def _eval_metric_ops(self, labels, logits, logistic, class_ids, weights, 9233f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. 
Unique TensorFlower unreduced_loss, regularization_loss): 92499d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower with ops.name_scope(None, 'metrics', 92599d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower (labels, logits, logistic, class_ids, weights, 9263f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower unreduced_loss, regularization_loss)): 92734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower keys = metric_keys.MetricKeys 92834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels_mean = _indicator_labels_mean( 92934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower labels=labels, weights=weights, name=keys.LABEL_MEAN) 93034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower metric_ops = { 93134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Estimator already adds a metric for loss. 932abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN): 933edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir metrics_lib.mean( 93420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower values=unreduced_loss, 93520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 93699d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower name=keys.LOSS_MEAN), 937abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.ACCURACY): 938edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir metrics_lib.accuracy( 939edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 940edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=class_ids, 941edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 942edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.ACCURACY), 943abf4aa037f21231594c44fc08fb50607de0288b7A. 
Unique TensorFlower _summary_key(self._name, keys.PREDICTION_MEAN): 944edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _predictions_mean( 945edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 946edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 947edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.PREDICTION_MEAN), 948abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.LABEL_MEAN): 949edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels_mean, 950abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.ACCURACY_BASELINE): 951edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _accuracy_baseline(labels_mean), 952abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.AUC): 953edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _auc( 954edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 955edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 956edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 957edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.AUC), 958abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower _summary_key(self._name, keys.AUC_PR): 959edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir _auc( 960edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir labels=labels, 961edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir predictions=logistic, 962edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir weights=weights, 963edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir curve='PR', 964edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir name=keys.AUC_PR) 96534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower } 9663f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower if regularization_loss is not None: 9673f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. 
Unique TensorFlower metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = ( 9683f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower metrics_lib.mean( 9693f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower values=regularization_loss, 9703f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower name=keys.LOSS_REGULARIZATION)) 97134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower for threshold in self._thresholds: 97234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold 973abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 9740fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower accuracy_key)] = _accuracy_at_threshold( 9750fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 9760fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 9770fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 9780fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower threshold=threshold, 9790fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=accuracy_key) 98034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Precision for positive examples. 98134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower precision_key = keys.PRECISION_AT_THRESHOLD % threshold 982abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 9830fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower precision_key)] = _precision_at_threshold( 9840fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 9850fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 9860fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 9870fab1a5d397d3d44eb4df84b1a81b30677873e8fA. 
Unique TensorFlower threshold=threshold, 9880fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=precision_key) 98934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Recall for positive examples. 99034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower recall_key = keys.RECALL_AT_THRESHOLD % threshold 991abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower metric_ops[_summary_key(self._name, 9920fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower recall_key)] = _recall_at_threshold( 9930fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower labels=labels, 9940fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower predictions=logistic, 9950fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weights=weights, 9960fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower threshold=threshold, 9970fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower name=recall_key) 99834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return metric_ops 99934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 10008d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower def create_loss(self, features, mode, logits, labels): 10018d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower """See `Head`.""" 100299d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower del mode # Unused for this head. 1003b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower logits = ops.convert_to_tensor(logits) 1004b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower labels = _check_dense_labels_match_logits_and_reshape( 1005b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower labels=labels, logits=logits, expected_labels_dimension=1) 10068d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower if self._label_vocabulary is not None: 10078d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. 
Unique TensorFlower labels = lookup_ops.index_table_from_tensor( 10088d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower vocabulary_list=tuple(self._label_vocabulary), 10098d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower name='class_id_lookup').lookup(labels) 10108d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = math_ops.to_float(labels) 10118d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels = _assert_range(labels, 2) 10123942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if self._loss_fn: 10133942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = _call_loss_fn( 10143942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=self._loss_fn, labels=labels, logits=logits, 10153942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower features=features, expected_loss_dim=1) 10163942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower else: 10173942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = nn.sigmoid_cross_entropy_with_logits( 10183942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower labels=labels, logits=logits) 1019d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower weights = _get_weights_and_check_match_logits( 1020d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower features=features, weight_column=self._weight_column, logits=logits) 102120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss = losses.compute_weighted_loss( 102220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower unweighted_loss, weights=weights, reduction=self._loss_reduction) 102399d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower return LossSpec( 102420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss=training_loss, 102520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower unreduced_loss=unweighted_loss, 102620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 10278d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower processed_labels=labels) 10288d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower 102934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def create_estimator_spec( 10303f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower self, features, mode, logits, labels=None, train_op_fn=None, 10313f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_losses=None): 10323f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower """Returns an `EstimatorSpec`. 10333f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower 10343f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower Args: 10353f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower features: Input `dict` of `Tensor` or `SparseTensor` objects. 10363f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower mode: Estimator's `ModeKeys`. 10373f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many 10383f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower applications, the shape is `[batch_size, 1]`. 10393f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower labels: Labels integer or string `Tensor` with shape matching `logits`, 10403f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required 10413f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower argument when `mode` equals `TRAIN` or `EVAL`. 10423f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower train_op_fn: Function that takes a scalar loss `Tensor` and returns 10433f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower `train_op`. Required in TRAIN mode. 
10443f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_losses: A list of additional scalar losses to be added to 10453f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower the training loss, such as regularization losses. These losses are 10463f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower usually expressed as a batch average, so for best results users need to 10473f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower set `loss_reduction=SUM_OVER_BATCH_SIZE` or 10483f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to 10493f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower avoid scaling errors. 10503f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower Returns: 10513f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower `EstimatorSpec`. 10523f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower Raises: 10533f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower ValueError: If `train_op_fn` is `None` in TRAIN mode. 10543f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower """ 1055169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir # Predict. 105601c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower with ops.name_scope(self._name, 'head'): 1057169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(None, 'predictions', (logits,)): 1058169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys = prediction_keys.PredictionKeys 1059b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower logits = _check_logits_final_dim(logits, self.logits_dimension) 1060169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC) 1061169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir two_class_logits = array_ops.concat( 1062b992761e569ac505e156127b64b91eba2f0953cbA. 
Unique TensorFlower (array_ops.zeros_like(logits), logits), 1063b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower axis=-1, name='two_class_logits') 1064df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower probabilities = nn.softmax( 1065df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower two_class_logits, name=pred_keys.PROBABILITIES) 1066b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower class_ids = math_ops.argmax( 1067b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower two_class_logits, axis=-1, name=pred_keys.CLASS_IDS) 1068b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower class_ids = array_ops.expand_dims(class_ids, axis=-1) 1069169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir if self._label_vocabulary: 1070169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir table = lookup_ops.index_to_string_table_from_tensor( 1071169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir vocabulary_list=self._label_vocabulary, 1072169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir name='class_string_lookup') 1073169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir classes = table.lookup(class_ids) 1074169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir else: 1075169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir classes = string_ops.as_string(class_ids, name='str_classes') 1076169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir predictions = { 1077169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.LOGITS: logits, 1078169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.LOGISTIC: logistic, 1079df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. 
Unique TensorFlower pred_keys.PROBABILITIES: probabilities, 1080169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.CLASS_IDS: class_ids, 1081169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir pred_keys.CLASSES: classes, 1082169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir } 108334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.PREDICT: 1084df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower classifier_output = _classification_output( 1085df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower scores=probabilities, n_classes=2, 1086df22cf83a21b62838ecf6f3a1c8a9c30ab20d482A. Unique TensorFlower label_vocabulary=self._label_vocabulary) 108734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 108834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.PREDICT, 108934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 1090edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir export_outputs={ 10914db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _DEFAULT_SERVING_KEY: classifier_output, 10924db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _CLASSIFY_SERVING_KEY: classifier_output, 10934db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _REGRESS_SERVING_KEY: export_output.RegressionOutput( 10944db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel value=logistic), 10954db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions) 1096edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir }) 109734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 109820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower (training_loss, unreduced_loss, weights, processed_labels) = ( 109920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower self.create_loss( 110020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower features=features, mode=mode, logits=logits, labels=labels)) 11013f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower if regularization_losses: 11023f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss = math_ops.add_n(regularization_losses) 11033f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularized_training_loss = math_ops.add_n( 11043f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower [training_loss, regularization_loss]) 11053f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower else: 11063f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss = None 11073f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularized_training_loss = training_loss 110899d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower 110934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Eval. 111034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.EVAL: 111134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 111234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.EVAL, 111334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 11143f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower loss=regularized_training_loss, 111534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower eval_metric_ops=self._eval_metric_ops( 11168d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower labels=processed_labels, 111734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower logits=logits, 111834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower logistic=logistic, 1119edb5fed7fcb23f2f7ad8f556eb44c0a8213184caMustafa Ispir class_ids=class_ids, 1120b992761e569ac505e156127b64b91eba2f0953cbA. Unique TensorFlower weights=weights, 11213f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. 
Unique TensorFlower unreduced_loss=unreduced_loss, 11223f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss=regularization_loss)) 112334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 112434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Train. 112534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if train_op_fn is None: 112634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower raise ValueError('train_op_fn can not be None.') 112720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # Only summarize mean_loss for SUM reduction to preserve backwards 112820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # compatibility. Otherwise skip it to avoid unnecessary computation. 112920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if self._loss_reduction == losses.Reduction.SUM: 113020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower example_weight_sum = math_ops.reduce_sum( 113120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights * array_ops.ones_like(unreduced_loss)) 113220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = training_loss / example_weight_sum 113320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower else: 113420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = None 1135169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(''): 11363f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower keys = metric_keys.MetricKeys 11370fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower summary.scalar( 11383f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower _summary_key(self._name, keys.LOSS), 11393f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularized_training_loss) 114020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if mean_loss is not None: 114120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower summary.scalar( 11423f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN), mean_loss) 11433f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower if regularization_loss is not None: 11443f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower summary.scalar( 11453f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower _summary_key(self._name, keys.LOSS_REGULARIZATION), 11463f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss) 1147169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir return model_fn.EstimatorSpec( 1148169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir mode=model_fn.ModeKeys.TRAIN, 1149169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir predictions=predictions, 11503f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower loss=regularized_training_loss, 11513f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower train_op=train_op_fn(regularized_training_loss)) 115234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 115334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 115420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlowerdef _regression_head_with_mean_squared_error_loss( 115520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weight_column=None, 115620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower label_dimension=1, 115720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 11583942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 115920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower name=None): 1160d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower """Creates a `_Head` for regression using the `mean_squared_error` loss. 116134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 116267c2ab669448828dc722af651917aa9abd01abf7A. 
Unique TensorFlower The loss is the weighted sum over all input dimensions. Namely, if the input 116367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower labels have shape `[batch_size, label_dimension]`, the loss is the weighted 116467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower sum over both `batch_size` and `label_dimension`. 116567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 116667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. 116767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower In many applications, the shape is `[batch_size, label_dimension]`. 116867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 116967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower The `labels` shape must match `logits`, namely 117067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape 117167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN]` is also supported. 117267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 117367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower If `weight_column` is specified, weights must be of shape 117467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or 117567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN, label_dimension]`. 117667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 11773942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or 11783942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower `(labels, logits, features)` as arguments and returns unreduced loss with 11793942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower shape `[D0, D1, ... DN, label_dimension]`. 
11803942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower 118134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Args: 1182d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir weight_column: A string or a `_NumericColumn` created by 1183d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir `tf.feature_column.numeric_column` defining feature column representing 118434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower weights. It is used to down weight or boost examples during training. It 118534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower will be multiplied by the loss of the example. 118634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower label_dimension: Number of regression labels per example. This is the size 118734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower of the last dimension of the labels `Tensor` (typically, this has shape 118834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower `[batch_size, label_dimension]`). 118920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to 119020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower reduce training loss over batch. Defaults to `SUM`. 11913942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn: Optional loss function. 1192abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name: name of the head. If provided, summary and metrics keys will be 119301c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. Unique TensorFlower suffixed by `"/" + name`. Also used as `name_scope` when creating ops. 119434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 119534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower Returns: 119634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower An instance of `_Head` for linear regression. 119720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower 119820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower Raises: 119920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower ValueError: If `label_dimension` or `loss_reduction` is invalid. 120034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """ 120120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if (loss_reduction not in losses.Reduction.all() or 120220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction == losses.Reduction.NONE): 120320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) 12043942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if loss_fn: 12053942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower _validate_loss_fn_args(loss_fn) 120634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return _RegressionHeadWithMeanSquaredErrorLoss( 12070fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower weight_column=weight_column, 12080fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower label_dimension=label_dimension, 120920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=loss_reduction, 12103942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=loss_fn, 1211abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower name=name) 121234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 121334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 121434e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlowerclass _RegressionHeadWithMeanSquaredErrorLoss(_Head): 121534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower """`Head` for regression using the mean squared loss.""" 121634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 121720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower def __init__( 121820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. 
Unique TensorFlower self, 121920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower label_dimension, 122020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weight_column=None, 122120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower loss_reduction=losses.Reduction.SUM, 12223942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=None, 122320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower name=None): 1224d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir """`Head` for regression.""" 122534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if label_dimension < 1: 122634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower raise ValueError('Invalid label_dimension %s.' % label_dimension) 122734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower self._logits_dimension = label_dimension 1228d35cbbb4477486a28481b82a0e441dfdff78a780Mustafa Ispir self._weight_column = weight_column 122920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower self._loss_reduction = loss_reduction 12303942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower self._loss_fn = loss_fn 1231abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower self._name = name 1232abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower 1233abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower @property 1234abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower def name(self): 1235abf4aa037f21231594c44fc08fb50607de0288b7A. Unique TensorFlower return self._name 123634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 123734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower @property 123834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def logits_dimension(self): 123934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return self._logits_dimension 124034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower 12418d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower def create_loss(self, features, mode, logits, labels): 12428d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower """See `Head`.""" 124399d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower del mode # Unused for this head. 124467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower logits = ops.convert_to_tensor(logits) 124567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower labels = _check_dense_labels_match_logits_and_reshape( 124667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower labels=labels, logits=logits, 124767c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower expected_labels_dimension=self._logits_dimension) 124827a7e5cfdb4ef9d5e3b710873c428cb44630622aA. Unique TensorFlower labels = math_ops.to_float(labels) 12493942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower if self._loss_fn: 12503942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = _call_loss_fn( 12513942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower loss_fn=self._loss_fn, labels=labels, logits=logits, 12523942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower features=features, expected_loss_dim=self._logits_dimension) 12533942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower else: 12543942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower unweighted_loss = losses.mean_squared_error( 12553942de820958e75792c86a50084f9312b5edd3baA. Unique TensorFlower labels=labels, predictions=logits, reduction=losses.Reduction.NONE) 1256d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower weights = _get_weights_and_check_match_logits( 1257d4f1845fc1aa57a2e872edc28d66901798205a85A. Unique TensorFlower features=features, weight_column=self._weight_column, logits=logits, 1258d4f1845fc1aa57a2e872edc28d66901798205a85A. 
Unique TensorFlower allow_per_logit_weights=True) 125920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss = losses.compute_weighted_loss( 126020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower unweighted_loss, weights=weights, reduction=self._loss_reduction) 126199d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower return LossSpec( 126220d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss=training_loss, 126320d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower unreduced_loss=unweighted_loss, 126420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights, 12658d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower processed_labels=labels) 12668d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower 126734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower def create_estimator_spec( 12683f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower self, features, mode, logits, labels=None, train_op_fn=None, 12693f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_losses=None): 127067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower """Returns an `EstimatorSpec`. 127167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower 127267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower Args: 127367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower features: Input `dict` of `Tensor` or `SparseTensor` objects. 127467c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower mode: Estimator's `ModeKeys`. 127567c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`. 127667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower For many applications, the shape is `[batch_size, logits_dimension]`. 127767c2ab669448828dc722af651917aa9abd01abf7A. 
Unique TensorFlower labels: Labels `Tensor` with shape matching `logits`, namely 127867c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape 127967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `[D0, D1, ... DN]` is also supported. `labels` is required argument when 128067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `mode` equals `TRAIN` or `EVAL`. 128167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower train_op_fn: Function that takes a scalar loss `Tensor` and returns 128267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `train_op`. Required in TRAIN mode. 12833f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_losses: A list of additional scalar losses to be added to 12843f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower the training loss, such as regularization losses. These losses are 12853f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower usually expressed as a batch average, so for best results users need to 12863f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower set `loss_reduction=SUM_OVER_BATCH_SIZE` or 12873f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to 12883f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower avoid scaling errors. 128967c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower Returns: 129067c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower `EstimatorSpec`. 129167c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower Raises: 129267c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower ValueError: If `train_op_fn` is `None` in TRAIN mode. 129367c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower """ 1294169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir # Predict. 129501c76110eb3cb1c378c9d7a14ca9f838bad6c7d1A. 
Unique TensorFlower with ops.name_scope(self._name, 'head'): 129667c2ab669448828dc722af651917aa9abd01abf7A. Unique TensorFlower logits = _check_logits_final_dim(logits, self._logits_dimension) 129734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits} 129834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.PREDICT: 12994db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel regression_output = export_output.RegressionOutput(value=logits) 130034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower return model_fn.EstimatorSpec( 130134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.PREDICT, 130234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 13034db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel export_outputs={ 13044db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _DEFAULT_SERVING_KEY: regression_output, 13054db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _REGRESS_SERVING_KEY: regression_output, 13064db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions) 13074db19c158148ed7d95e8b7f7f56050a82f76bec6David Soergel }) 130834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 130920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower training_loss, unreduced_loss, weights, _ = self.create_loss( 13108d75705ffa7afa259a7e7d2aa9b7f05a559d4602A. Unique TensorFlower features=features, mode=mode, logits=logits, labels=labels) 13113f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower if regularization_losses: 13123f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss = math_ops.add_n(regularization_losses) 13133f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularized_training_loss = math_ops.add_n( 13143f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. 
Unique TensorFlower [training_loss, regularization_loss]) 13153f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower else: 13163f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss = None 13173f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularized_training_loss = training_loss 131899d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower 131999d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower # Eval. 132034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if mode == model_fn.ModeKeys.EVAL: 13213f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower keys = metric_keys.MetricKeys 132234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Estimator already adds a metric for loss. 132334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower eval_metric_ops = { 13243f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower _summary_key(self._name, keys.LOSS_MEAN): 132599d51b8da87ba462f66a6be90212677f2cae9e32A. Unique TensorFlower metrics_lib.mean( 132620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower values=unreduced_loss, 132720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights=weights) 132834e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower } 13293f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower if regularization_loss is not None: 13303f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower regularization_loss_key = _summary_key( 13313f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower self._name, keys.LOSS_REGULARIZATION) 13323f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower eval_metric_ops[regularization_loss_key] = metrics_lib.mean( 13333f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower values=regularization_loss, 13343f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower name=keys.LOSS_REGULARIZATION) 133534e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. 
Unique TensorFlower return model_fn.EstimatorSpec( 133634e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower mode=model_fn.ModeKeys.EVAL, 133734e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower predictions=predictions, 13383f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower loss=regularized_training_loss, 133934e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower eval_metric_ops=eval_metric_ops) 134034e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower 134134e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower # Train. 134234e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower if train_op_fn is None: 134334e264a785cbd4cc1fcbf5d9cfc18c131fd5b206A. Unique TensorFlower raise ValueError('train_op_fn can not be None.') 134420d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # Only summarize mean_loss for SUM reduction to preserve backwards 134520d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower # compatibility. Otherwise skip it to avoid unnecessary computation. 134620d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower if self._loss_reduction == losses.Reduction.SUM: 134720d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower example_weight_sum = math_ops.reduce_sum( 134820d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower weights * array_ops.ones_like(unreduced_loss)) 134920d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = training_loss / example_weight_sum 135020d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower else: 135120d3c083e3b039feb7310ef0402dfc6da8fd0c19A. Unique TensorFlower mean_loss = None 1352169de9361e7463bb2389571644df2b6a67929a3eMustafa Ispir with ops.name_scope(''): 13533f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. Unique TensorFlower keys = metric_keys.MetricKeys 13540fab1a5d397d3d44eb4df84b1a81b30677873e8fA. Unique TensorFlower summary.scalar( 13553f3c9f73e269d1a24b7bb2841d17ad6be3353b7eA. 
def _assert_range(labels, n_classes, message=None):
  """Returns `labels` gated on assertions that it lies in `[0, n_classes)`.

  Args:
    labels: Label IDs to validate; must expose a `dtype` attribute.
    n_classes: Exclusive upper bound, converted to a tensor of `labels.dtype`.
    message: Optional message used for both assertions in place of the
      built-in defaults.

  Returns:
    An identity of `labels` that carries control dependencies on both the
    upper-bound and the non-negativity assertion.
  """
  too_large_msg = message or 'Label IDs must < n_classes'
  negative_msg = message or 'Label IDs must >= 0'
  with ops.name_scope(None, 'assert_range', (labels,)):
    upper_bound = ops.convert_to_tensor(n_classes, dtype=labels.dtype)
    # The upper-bound check is built first, then the non-negativity check.
    range_checks = (
        check_ops.assert_less(labels, upper_bound, message=too_large_msg),
        check_ops.assert_non_negative(labels, message=negative_msg),
    )
    with ops.control_dependencies(range_checks):
      return array_ops.identity(labels)


# TODO(b/69000400): Delete this method.
def _weights(features, weight_column):
  """Fetches weights from features.

  Args:
    features: Mapping of features. Its values seed the name scope and the
      mapping itself is wrapped in a `_LazyBuilder` to look up the weights.
    weight_column: `None`, a string key, or a `_NumericColumn`. A string is
      turned into a numeric column of shape `(1,)` keyed by that string.

  Returns:
    `1.` when `weight_column` is `None`; otherwise the weight tensor cast to
    float and named 'weights'.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the weight tensor's dtype is neither floating nor integer.
  """
  with ops.name_scope(None, 'weights', values=features.values()):
    if weight_column is None:
      return 1.
    column = weight_column
    if isinstance(column, six.string_types):
      column = feature_column_lib.numeric_column(key=column, shape=(1,))
    # pylint: disable=protected-access
    if not isinstance(column, feature_column_lib._NumericColumn):
      raise TypeError('Weight column must be either a string or _NumericColumn.'
                      ' Given type: {}.'.format(type(column)))
    builder = feature_column_lib._LazyBuilder(features)
    weight_tensor = column._get_dense_tensor(builder)
    # pylint: enable=protected-access
    tensor_dtype = weight_tensor.dtype
    if not tensor_dtype.is_floating and not tensor_dtype.is_integer:
      raise ValueError('Weight column should be castable to float. '
                       'Given dtype: {}'.format(tensor_dtype))
    return math_ops.to_float(weight_tensor, name='weights')