1b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
3b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Licensed under the Apache License, Version 2.0 (the "License");
4b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# you may not use this file except in compliance with the License.
5b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# You may obtain a copy of the License at
6b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
7b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#     http://www.apache.org/licenses/LICENSE-2.0
8b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho#
9b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# Unless required by applicable law or agreed to in writing, software
10b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# distributed under the License is distributed on an "AS IS" BASIS,
11b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# See the License for the specific language governing permissions and
13b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# limitations under the License.
14b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho# ==============================================================================
15b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho"""Logistic regression (aka binary classifier) class.
16b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
17b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei HoThis defines some useful basic metrics for using logistic regression to classify
18b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hoa binary event (0 vs 1).
19b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho"""
20b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
21b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import absolute_import
22b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import division
23b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom __future__ import print_function
24b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
25b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.contrib import metrics as metrics_lib
263f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlowerfrom tensorflow.contrib.learn.python.learn.estimators import constants
27b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.contrib.learn.python.learn.estimators import estimator
286f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlowerfrom tensorflow.contrib.learn.python.learn.estimators import metric_key
296f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlowerfrom tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
30b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Hofrom tensorflow.python.ops import math_ops
31b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
32b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
336f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlowerdef _get_model_fn_with_logistic_metrics(model_fn):
346f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  """Returns a model_fn with additional logistic metrics.
35b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
366f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  Args:
376f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    model_fn: Model function with the signature:
386f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower      `(features, labels, mode) -> (predictions, loss, train_op)`.
396f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower      Expects the returned predictions to be probabilities in [0.0, 1.0].
40b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
416f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  Returns:
426f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    model_fn that can be used with Estimator.
436f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  """
44b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
456f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  def _model_fn(features, labels, mode, params):
466f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    """Model function that appends logistic evaluation metrics."""
476f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    thresholds = params.get('thresholds') or [.5]
48b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
496f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    predictions, loss, train_op = model_fn(features, labels, mode)
506f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    if mode == model_fn_lib.ModeKeys.EVAL:
516f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower      eval_metric_ops = _make_logistic_eval_metric_ops(
526f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower          labels=labels,
536f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower          predictions=predictions,
546f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower          thresholds=thresholds)
556f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    else:
566f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower      eval_metric_ops = None
576f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower    return model_fn_lib.ModelFnOps(
586f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower        mode=mode,
596f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower        predictions=predictions,
606f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower        loss=loss,
616f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower        train_op=train_op,
623f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower        eval_metric_ops=eval_metric_ops,
633f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower        output_alternatives={
643f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower            'head': (constants.ProblemType.LOGISTIC_REGRESSION, {
653f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower                'predictions': predictions
663f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower            })
673f6e5288c8b9d367bf1939f0a89bfb7568e1f731A. Unique TensorFlower        })
68b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
696f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower  return _model_fn
70b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
71b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
726f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower# TODO(roumposg): Deprecate and delete after converting users to use head.
def LogisticRegressor(  # pylint: disable=invalid-name
    model_fn, thresholds=None, model_dir=None, config=None,
    feature_engineering_fn=None):
  """Builds a logistic regression Estimator for binary classification.

  This method provides a basic Estimator with some additional metrics for custom
  binary classification models, including AUC, precision/recall and accuracy.

  Example:

  ```python
    # See tf.contrib.learn.Estimator(...) for details on model_fn structure
    def my_model_fn(features, labels, mode):
      pass

    estimator = LogisticRegressor(model_fn=my_model_fn)

    # Input builders
    def input_fn_train():
      pass

    def input_fn_predict():
      pass

    estimator.fit(input_fn=input_fn_train)
    estimator.predict(input_fn=input_fn_predict)
  ```

  Args:
    model_fn: Model function with the signature:
      `(features, labels, mode) -> (predictions, loss, train_op)`.
      Expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If `None` (or empty), defaults to
      `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    config: A RunConfig configuration object.
    feature_engineering_fn: Feature engineering function. Takes features and
                      labels which are the output of `input_fn` and
                      returns features and labels which will be fed
                      into the model.

  Returns:
    An `Estimator` instance.
  """
  # `thresholds` travels through `params`; the wrapped model_fn substitutes
  # [.5] when it is None or empty.
  return estimator.Estimator(
      model_fn=_get_model_fn_with_logistic_metrics(model_fn),
      model_dir=model_dir,
      config=config,
      params={'thresholds': thresholds},
      feature_engineering_fn=feature_engineering_fn)
1226f6d4af798cd77d0299f3880a1343558ccb03bcfA. Unique TensorFlower
123b22e18fc71d3301e4ad91f1a232aefc3c8de7ab1Wei Ho
def _make_logistic_eval_metric_ops(labels, predictions, thresholds):
  """Returns a dictionary of evaluation metric ops for logistic regression.

  Args:
    labels: The labels `Tensor`, or a dict with only one `Tensor` keyed by name.
    predictions: The predictions `Tensor`.
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics.

  Returns:
    A dict mapping `metric_key.MetricKey` names to the metric ops returned by
    the corresponding `metrics_lib.streaming_*` functions.
  """
  # If labels is a dict with a single key, unpack into a single tensor.
  # `dict.values()` is a non-indexable view on Python 3, so take the sole
  # value through an iterator rather than `labels.values()[0]`.
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = next(iter(labels.values()))

  metrics = {}
  metrics[metric_key.MetricKey.PREDICTION_MEAN] = metrics_lib.streaming_mean(
      predictions)
  metrics[metric_key.MetricKey.LABEL_MEAN] = metrics_lib.streaming_mean(
      labels_tensor)
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[metric_key.MetricKey.ACCURACY_BASELINE] = metrics_lib.streaming_mean(
      labels_tensor)

  metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
      labels=labels_tensor, predictions=predictions)

  for threshold in thresholds:
    # Binarize the probabilities: 1.0 where prediction >= threshold, else 0.0.
    predictions_at_threshold = math_ops.to_float(
        math_ops.greater_equal(predictions, threshold),
        name='predictions_at_threshold_%f' % threshold)
    metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = (
        metrics_lib.streaming_accuracy(labels=labels_tensor,
                                       predictions=predictions_at_threshold))
    # Precision for positive examples.
    metrics[metric_key.MetricKey.PRECISION_MEAN % threshold] = (
        metrics_lib.streaming_precision(labels=labels_tensor,
                                        predictions=predictions_at_threshold))
    # Recall for positive examples.
    metrics[metric_key.MetricKey.RECALL_MEAN % threshold] = (
        metrics_lib.streaming_recall(labels=labels_tensor,
                                     predictions=predictions_at_threshold))

  return metrics
171