logistic_regressor.py revision 6f6d4af798cd77d0299f3880a1343558ccb03bcf
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logistic regression (aka binary classifier) class.

This defines some useful basic metrics for using logistic regression to classify
a binary event (0 vs 1).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.python.ops import math_ops


def _get_model_fn_with_logistic_metrics(model_fn):
  """Returns a model_fn with additional logistic metrics.

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].

  Returns:
    model_fn that can be used with Estimator.
  """

  def _model_fn(features, labels, mode, params):
    """Model function that appends logistic evaluation metrics."""
    # Fall back to a single 0.5 threshold when none (or an empty list) is given.
    thresholds = params.get('thresholds') or [.5]

    predictions, loss, train_op = model_fn(features, labels, mode)
    # The extra metrics are only built in EVAL mode.
    if mode == model_fn_lib.ModeKeys.EVAL:
      eval_metric_ops = _make_logistic_eval_metric_ops(
          labels=labels,
          predictions=predictions,
          thresholds=thresholds)
    else:
      eval_metric_ops = None

    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return _model_fn


# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  This method provides a basic Estimator with some additional metrics for custom
  binary classification models, including AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(features, labels, mode):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(x=x)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
      labels which are the output of `input_fn` and returns features and
      labels which will be fed into the model.

  Returns:
    A `tf.contrib.learn.Estimator` instance.
  """
  return estimator.Estimator(
      model_fn=_get_model_fn_with_logistic_metrics(model_fn),
      model_dir=model_dir,
      config=config,
      params={'thresholds': thresholds},
      feature_engineering_fn=feature_engineering_fn)


def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict of metric results keyed by name.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    # dict.values() is a view in Python 3 and cannot be indexed directly.
    labels_tensor = list(labels.values())[0]

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize: 1.0 where the predicted probability is >= threshold, else 0.0.
    predictions_at_threshold = math_ops.to_float(
        math_ops.greater_equal(predictions, threshold),
        name='predictions_at_threshold_%f' % threshold)
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics
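
# Illustrative sketch only, not part of the module above. It mirrors the
# per-threshold loop in `_make_logistic_eval_metric_ops` with plain Python
# lists instead of streaming TensorFlow ops, to show what accuracy, precision
# and recall "at a threshold" mean. The helper name `thresholded_metrics`, its
# result layout, and the choice to report 0.0 when a ratio is undefined are
# assumptions of this sketch.

```python
def thresholded_metrics(labels, predictions, thresholds=(0.5,)):
  """Hypothetical helper: accuracy, precision and recall per threshold."""
  results = {}
  for threshold in thresholds:
    # Same rule as the module: predict 1 when probability >= threshold.
    predicted = [1 if p >= threshold else 0 for p in predictions]
    pairs = list(zip(labels, predicted))

    true_pos = sum(1 for y, y_hat in pairs if y == 1 and y_hat == 1)
    false_pos = sum(1 for y, y_hat in pairs if y == 0 and y_hat == 1)
    false_neg = sum(1 for y, y_hat in pairs if y == 1 and y_hat == 0)
    correct = sum(1 for y, y_hat in pairs if y == y_hat)

    results[threshold] = {
        # Assumption: report 0.0 when a denominator would be zero.
        'accuracy': correct / len(pairs) if pairs else 0.0,
        'precision': (true_pos / (true_pos + false_pos)
                      if true_pos + false_pos else 0.0),
        'recall': (true_pos / (true_pos + false_neg)
                   if true_pos + false_neg else 0.0),
    }
  return results


example = thresholded_metrics(
    labels=[1, 0, 1, 0],
    predictions=[0.9, 0.6, 0.4, 0.1],
    thresholds=[0.4, 0.5])
print(example[0.5])
```

# Lowering the threshold from 0.5 to 0.4 in this toy data turns the third
# example into a predicted positive, which raises recall at that threshold.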