logistic_regressor.py revision 6f6d4af798cd77d0299f3880a1343558ccb03bcf
1b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
3b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Licensed under the Apache License, Version 2.0 (the "License");
4b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# you may not use this file except in compliance with the License.
5b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# You may obtain a copy of the License at
6b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
7b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#     http://www.apache.org/licenses/LICENSE-2.0
8b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
9b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Unless required by applicable law or agreed to in writing, software
10b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# distributed under the License is distributed on an "AS IS" BASIS,
11b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# See the License for the specific language governing permissions and
13b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# limitations under the License.
14b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# ==============================================================================
15b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho"""Logistic regression (aka binary classifier) class.
16b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
17b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei HoThis defines some useful basic metrics for using logistic regression to classify
18b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hoa binary event (0 vs 1).
19b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho"""
20b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
21b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import absolute_import
22b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import division
23b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import print_function
24b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
25b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.contrib import metrics as metrics_lib
26b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.contrib.learn.python.learn.estimators import estimator
276f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlowerfrom tensorflow.contrib.learn.python.learn.estimators import metric_key
286f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlowerfrom tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
29b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.python.ops import math_ops
30b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
31b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
def _get_model_fn_with_logistic_metrics(model_fn):
  """Returns a model_fn with additional logistic metrics.

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].

  Returns:
    model_fn that can be used with Estimator. The returned function has the
    signature `(features, labels, mode, params)`. It reads the optional
    `'thresholds'` entry of `params`. It falls back to `[0.5]` when that entry
    is missing, `None` or empty, or when `params` itself is `None`.
  """

  def _model_fn(features, labels, mode, params):
    """Model function that appends logistic evaluation metrics."""
    # Guard against `params` being None (e.g. if this wrapper is used with an
    # Estimator constructed without `params`). An empty/None thresholds entry
    # falls back to the default single 0.5 threshold.
    thresholds = (params or {}).get('thresholds') or [.5]

    predictions, loss, train_op = model_fn(features, labels, mode)
    if mode == model_fn_lib.ModeKeys.EVAL:
      # Logistic metrics are only built in EVAL mode; every other mode gets
      # no eval_metric_ops.
      eval_metric_ops = _make_logistic_eval_metric_ops(
          labels=labels,
          predictions=predictions,
          thresholds=thresholds)
    else:
      eval_metric_ops = None

    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return _model_fn
65b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
66b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  This method provides a basic Estimator with some additional metrics for custom
  binary classification models, including AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(features, labels, mode):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    def input_fn_predict():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(input_fn=input_fn_predict)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
                      labels which are the output of `input_fn` and
                      returns features and labels which will be fed
                      into the model.

  Returns:
    A `tf.contrib.learn.Estimator` instance.
  """
  # `thresholds` travels through `params`; the wrapped model_fn reads it back
  # from `params['thresholds']` and applies the [0.5] default there.
  return estimator.Estimator(
      model_fn=_get_model_fn_with_logistic_metrics(model_fn),
      model_dir=model_dir,
      config=config,
      params={'thresholds': thresholds},
      feature_engineering_fn=feature_engineering_fn)
1166f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower      feature_engineering_fn=feature_engineering_fn)
1176f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower
118b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict of metric results keyed by name. It contains prediction/label means,
    the accuracy baseline and AUC, plus one accuracy, precision and recall
    entry per threshold.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  # `dict.values()` is a non-indexable view on Python 3, so take its single
  # element through an iterator instead of `values()[0]`.
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = next(iter(labels.values()))

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize the probabilities: 1.0 where prediction >= threshold, else 0.0.
    predictions_at_threshold = math_ops.to_float(
        math_ops.greater_equal(predictions, threshold),
        name='predictions_at_threshold_%f' % threshold)
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics
166