# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logistic regression (aka binary classifier) class.

This defines some useful basic metrics for using logistic regression to classify
a binary event (0 vs 1).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.python.ops import math_ops


def _get_model_fn_with_logistic_metrics(model_fn):
  """Returns a model_fn with additional logistic metrics.

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].

  Returns:
    model_fn that can be used with Estimator. It has the signature
    `(features, labels, mode, params) -> ModelFnOps` and reads the optional
    `'thresholds'` entry from `params`.
  """

  def _model_fn(features, labels, mode, params):
    """Model function that appends logistic evaluation metrics."""
    # `or` (rather than a .get() default) also maps an explicit None or an
    # empty list to the default single 0.5 threshold.
    thresholds = params.get('thresholds') or [.5]

    predictions, loss, train_op = model_fn(features, labels, mode)
    if mode == model_fn_lib.ModeKeys.EVAL:
      eval_metric_ops = _make_logistic_eval_metric_ops(
          labels=labels,
          predictions=predictions,
          thresholds=thresholds)
    else:
      eval_metric_ops = None
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        output_alternatives={
            'head': (constants.ProblemType.LOGISTIC_REGRESSION, {
                'predictions': predictions
            })
        })

  return _model_fn


# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  This method provides a basic Estimator with some additional metrics for custom
  binary classification models, including AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(...):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(x=x)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
                      labels which are the output of `input_fn` and
                      returns features and labels which will be fed
                      into the model.

  Returns:
    An `Estimator` instance.
  """
  return estimator.Estimator(
      model_fn=_get_model_fn_with_logistic_metrics(model_fn),
      model_dir=model_dir,
      config=config,
      params={'thresholds': thresholds},
      feature_engineering_fn=feature_engineering_fn)


def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict of metric results keyed by name.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  # dict.values() is a non-indexable view on Python 3, so take the first
  # element through an iterator (works on Python 2 as well).
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = next(iter(labels.values()))

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize the probabilities: 1.0 where prediction >= threshold, else 0.0.
    predictions_at_threshold = math_ops.to_float(
        math_ops.greater_equal(predictions, threshold),
        name='predictions_at_threshold_%f' % threshold)
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics