1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Demonstrates multiclass MNIST TF Boosted trees example.
16
  This example demonstrates how to run experiments with TF Boosted Trees on
  the MNIST dataset. We are using layer-by-layer boosting with the diagonal
  hessian strategy for multiclass handling, and cross entropy loss.
20
21  Example Usage:
22  python tensorflow/contrib/boosted_trees/examples/mnist.py \
23  --output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000  \
24  --examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \
25  --num_trees=10 --l2=1 --vmodule=training_ops=1
26
27  When training is done, accuracy on eval data is reported. Point tensorboard
28  to the directory for the run to see how the training progresses:
29
30  tensorboard --logdir=/tmp/mnist
31
32"""
33from __future__ import absolute_import
34from __future__ import division
35from __future__ import print_function
36
37import argparse
38import sys
39
40import numpy as np
41import tensorflow as tf
42from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
43from tensorflow.contrib.boosted_trees.proto import learner_pb2
44from tensorflow.contrib.learn import learn_runner
45
46
def get_input_fn(dataset_split,
                 batch_size,
                 capacity=10000,
                 min_after_dequeue=3000):
  """Returns an input_fn serving shuffled batches from one MNIST split.

  Args:
    dataset_split: object exposing `images` and `labels` arrays (for example
      `data.train` or `data.validation` from `load_mnist()`).
    batch_size: number of examples per produced batch.
    capacity: queue capacity, forwarded to `tf.train.shuffle_batch`.
    min_after_dequeue: forwarded to `tf.train.shuffle_batch`.

  Returns:
    A zero-argument callable that, when invoked, returns a
    `(features, labels)` pair where `features` is `{"images": <batch>}` and
    `labels` holds the int32-cast labels.
  """

  def _input_fn():
    """Builds the shuffled (features, labels) batch tensors."""
    source_tensors = [
        dataset_split.images,
        # The labels are cast to int32 before being fed to the batching queue.
        dataset_split.labels.astype(np.int32),
    ]
    # enqueue_many=True: each row of the arrays is enqueued as one example.
    shuffled_images, shuffled_labels = tf.train.shuffle_batch(
        tensors=source_tensors,
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    return {"images": shuffled_images}, shuffled_labels

  return _input_fn
67
68
69# Main config - creates a TF Boosted Trees Estimator based on flags.
def _get_tfbt(output_dir):
  """Builds the TF Boosted Trees classifier from the parsed command-line flags.

  Reads the module-level `FLAGS` (learning_rate, l2, examples_per_layer,
  depth, num_trees).

  Args:
    output_dir: directory handed to the estimator as `model_dir`.

  Returns:
    A `GradientBoostedDecisionTreeClassifier` configured for 10 classes.
  """
  num_classes = 10

  config = learner_pb2.LearnerConfig()
  config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
  config.num_classes = num_classes

  # No L1 penalty; the --l2 flag value is divided by the number of examples
  # accumulated per layer before being stored in the config.
  config.regularization.l1 = 0.0
  config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
  config.constraints.max_tree_depth = FLAGS.depth
  config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER

  checkpoint_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)

  # Multiclass handling uses the diagonal-hessian strategy.
  config.multi_class_strategy = learner_pb2.LearnerConfig.DIAGONAL_HESSIAN

  return GradientBoostedDecisionTreeClassifier(
      learner_config=config,
      n_classes=num_classes,
      examples_per_layer=FLAGS.examples_per_layer,
      model_dir=output_dir,
      num_trees=FLAGS.num_trees,
      center_bias=False,
      config=checkpoint_config)
99
100
def _make_experiment_fn(output_dir):
  """Builds the Experiment that trains and evaluates the boosted trees model.

  Args:
    output_dir: model directory passed through to the estimator.

  Returns:
    A `tf.contrib.learn.Experiment` that trains on the MNIST train split
    (batches of --batch_size) and evaluates on the validation split
    (batches of --eval_batch_size) for --num_eval_steps steps.
  """
  mnist = tf.contrib.learn.datasets.mnist.load_mnist()
  training_fn = get_input_fn(mnist.train, FLAGS.batch_size)
  evaluation_fn = get_input_fn(mnist.validation, FLAGS.eval_batch_size)

  experiment = tf.contrib.learn.Experiment(
      estimator=_get_tfbt(output_dir),
      train_input_fn=training_fn,
      eval_input_fn=evaluation_fn,
      # No explicit step budget here; --num_trees is the flag described as
      # the stopping criterion.
      train_steps=None,
      eval_steps=FLAGS.num_eval_steps,
      eval_metrics=None)
  return experiment
114
115
def main(unused_argv):
  """Entry point for `tf.app.run`: runs train_and_evaluate via learn_runner.

  Args:
    unused_argv: ignored; flags were already parsed into `FLAGS` by argparse.
  """
  del unused_argv  # Not needed; configuration comes from FLAGS.
  learn_runner.run(
      experiment_fn=_make_experiment_fn,
      output_dir=FLAGS.output_dir,
      schedule="train_and_evaluate")
121
122
if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  flag_parser = argparse.ArgumentParser()

  # Output location, batching and evaluation flags.
  flag_parser.add_argument(
      "--output_dir", type=str, required=True,
      help="Choose the dir for the output.")
  flag_parser.add_argument(
      "--batch_size", type=int, default=1000,
      help="The batch size for reading data.")
  flag_parser.add_argument(
      "--eval_batch_size", type=int, default=1000,
      help="Size of the batch for eval.")
  flag_parser.add_argument(
      "--num_eval_steps", type=int, default=1,
      help="The number of steps to run evaluation for.")

  # Gradient boosted trees configuration flags.
  flag_parser.add_argument(
      "--depth", type=int, default=4,
      help="Maximum depth of weak learners.")
  flag_parser.add_argument(
      "--l2", type=float, default=1.0,
      help="l2 regularization per batch.")
  flag_parser.add_argument(
      "--learning_rate", type=float, default=0.1,
      help=("Learning rate (shrinkage weight) with which each new tree is "
            "added."))
  flag_parser.add_argument(
      "--examples_per_layer", type=int, default=1000,
      help="Number of examples to accumulate stats for per layer.")
  # required=True, so argparse's implicit default of None never applies.
  flag_parser.add_argument(
      "--num_trees", type=int, required=True,
      help="Number of trees to grow before stopping.")

  # FLAGS is read as a module-level global by _get_tfbt, _make_experiment_fn
  # and main. Unrecognized flags (e.g. --vmodule) are forwarded to tf.app.run.
  FLAGS, unparsed = flag_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
172