1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for creating split nodes using one or more features."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils
23from tensorflow.python.ops import control_flow_ops
24
25
26class BaseSplitHandler(object):
27  """Abstract Base class defining split handlers interface."""
28
29  __metaclass__ = abc.ABCMeta
30
31  def __init__(self,
32               l1_regularization,
33               l2_regularization,
34               tree_complexity_regularization,
35               min_node_weight,
36               feature_column_group_id,
37               gradient_shape,
38               hessian_shape,
39               multiclass_strategy,
40               name=None):
41    """Constructor for BaseSplitHandler.
42
43    Args:
44      l1_regularization: L1 regularization applied for this split handler.
45      l2_regularization: L2 regularization applied for this split handler.
46      tree_complexity_regularization: Tree complexity regularization applied
47          for this split handler.
48      min_node_weight: Minimum sum of weights of examples in each partition to
49          be considered for splitting.
50      feature_column_group_id: Feature column group index.
51      gradient_shape: A TensorShape, containing shape of gradients.
52      hessian_shape: A TensorShape, containing shape of hessians.
53      multiclass_strategy: Strategy describing how to treat multiclass problems.
54      name: An optional handler name.
55    """
56    self._l1_regularization = l1_regularization
57    self._l2_regularization = l2_regularization
58    self._tree_complexity_regularization = tree_complexity_regularization
59    self._min_node_weight = min_node_weight
60    self._feature_column_group_id = feature_column_group_id
61    self._name = name or ""
62    self._multiclass_strategy = multiclass_strategy
63    self._hessian_shape = hessian_shape
64    self._gradient_shape = gradient_shape
65
66  def scheduled_reads(self):
67    """Returns the list of `ScheduledOp`s required for update_stats."""
68    return []
69
70  @abc.abstractmethod
71  def update_stats(self, stamp_token, example_partition_ids, gradients,
72                   hessians, empty_gradients, empty_hessians, weights,
73                   is_active, scheduled_reads):
74    """Updates the state for this split handler.
75
76    Args:
77      stamp_token: An int32 scalar tensor containing the current stamp token.
78      example_partition_ids: A dense tensor, containing an int32 for each
79        example which is the partition id that the example ends up in.
80      gradients: A dense tensor of gradients.
81      hessians: A dense tensor of hessians.
82      empty_gradients: A dense empty tensor of the same shape (for dimensions >
83        0) as gradients.
84      empty_hessians: A dense empty tensor of the same shape (for dimensions >
85        0) as hessians.
86      weights: A dense float32 tensor with a weight for each example.
87      is_active: A boolean tensor that says if this handler is active or not.
88          One value for the current layer and one value for the next layer.
89      scheduled_reads: List of results from the scheduled reads.
90
91    Returns:
92      A tuple of the op that updates the stats for this handler and a list of
93      `ScheduledOp`s.
94    """
95
96  def update_stats_sync(self, stamp_token, example_partition_ids, gradients,
97                        hessians, empty_gradients, empty_hessians, weights,
98                        is_active):
99    """Updates the state for this split handler running the scheduled I/O.
100
101    Args:
102      stamp_token: An int32 scalar tensor containing the current stamp token.
103      example_partition_ids: A dense tensor, containing an int32 for each
104        example which is the partition id that the example ends up in.
105      gradients: A dense tensor of gradients.
106      hessians: A dense tensor of hessians.
107      empty_gradients: A dense empty tensor of the same shape (for dimensions >
108        0) as gradients.
109      empty_hessians: A dense empty tensor of the same shape (for dimensions >
110        0) as hessians.
111      weights: A dense float32 tensor with a weight for each example.
112      is_active: A boolean tensor that says if this handler is active or not.
113          One value for the current layer and one value for the next layer.
114
115    Returns:
116      Op that updates the stats for this handler.
117    """
118    handler_reads = {self: self.scheduled_reads()}
119    handler_results = batch_ops_utils.run_handler_scheduled_ops(
120        handler_reads, stamp_token, None)
121    update_1, scheduled_updates = self.update_stats(
122        stamp_token, example_partition_ids, gradients, hessians,
123        empty_gradients, empty_hessians, weights, is_active,
124        handler_results[self])
125    update_2 = batch_ops_utils.run_handler_scheduled_ops({
126        self: scheduled_updates
127    }, stamp_token, None)
128    return control_flow_ops.group(update_1, *update_2[self])
129
130  @abc.abstractmethod
131  def make_splits(self, stamp_token, next_stamp_token, class_id):
132    """Create the best split using the accumulated stats and flush the state.
133
134    This should only be called by the master.
135
136    Args:
137      stamp_token: An int32 scalar tensor containing the current stamp token.
138      next_stamp_token: An int32 scalar tensor containing the stamp token for
139        the next iteration.
140      class_id: what class id the handler gathers stats for (for tree per class
141        strategy). When class_id=-1, the strategy is not tree per class.
142    Returns:
143      A tuple (are_splits_ready, partition_id, gain, split_info) where
144      are_splits_ready is a scalar boolean tensor, partition_id is a rank 1,
145      int32 tensor, gain is a rank 1 float32 tensor and split_info is a rank 1
146      string tensor containing serialized SplitInfo protos.
147    """
148