# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for creating split nodes using one or more features."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils
from tensorflow.python.ops import control_flow_ops


class BaseSplitHandler(object):
  """Abstract Base class defining split handlers interface.

  Concrete subclasses implement `update_stats` (accumulate per-example
  gradient/hessian statistics) and `make_splits` (emit the best split per
  partition from the accumulated statistics).
  """

  # NOTE(review): `__metaclass__` is only honored under Python 2; under
  # Python 3 the abstractmethod decorators below are not enforced at
  # instantiation time. Presumably intentional for py2/py3-compat-era code —
  # confirm before relying on ABC enforcement.
  __metaclass__ = abc.ABCMeta

  def __init__(self,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               name=None):
    """Constructor for BaseSplitHandler.

    Args:
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      name: An optional handler name.
    """
    self._l1_regularization = l1_regularization
    self._l2_regularization = l2_regularization
    self._tree_complexity_regularization = tree_complexity_regularization
    self._min_node_weight = min_node_weight
    self._feature_column_group_id = feature_column_group_id
    # Empty string (rather than None) so the name can be safely concatenated
    # into op/scope names by subclasses.
    self._name = name or ""
    self._multiclass_strategy = multiclass_strategy
    self._hessian_shape = hessian_shape
    self._gradient_shape = gradient_shape

  def scheduled_reads(self):
    """Returns the list of `ScheduledOp`s required for update_stats.

    Base implementation schedules nothing; subclasses override this when
    `update_stats` needs state read ahead of time (results are passed back
    via the `scheduled_reads` argument of `update_stats`).
    """
    return []

  @abc.abstractmethod
  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for this split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
          example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
          0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
          0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.
      scheduled_reads: List of results from the scheduled reads.

    Returns:
      A tuple of the op that updates the stats for this handler and a list of
      `ScheduledOp`s.
    """

  def update_stats_sync(self, stamp_token, example_partition_ids, gradients,
                        hessians, empty_gradients, empty_hessians, weights,
                        is_active):
    """Updates the state for this split handler running the scheduled I/O.

    Convenience wrapper around `update_stats` that runs this handler's
    scheduled reads and scheduled updates inline instead of batching them
    with other handlers.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
          example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
          0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
          0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.

    Returns:
      Op that updates the stats for this handler.
    """
    # Phase 1: execute this handler's scheduled reads and feed the results
    # into update_stats.
    handler_reads = {self: self.scheduled_reads()}
    handler_results = batch_ops_utils.run_handler_scheduled_ops(
        handler_reads, stamp_token, None)
    update_1, scheduled_updates = self.update_stats(
        stamp_token, example_partition_ids, gradients, hessians,
        empty_gradients, empty_hessians, weights, is_active,
        handler_results[self])
    # Phase 2: execute the scheduled updates returned by update_stats, then
    # group both phases into a single op for the caller.
    update_2 = batch_ops_utils.run_handler_scheduled_ops({
        self: scheduled_updates
    }, stamp_token, None)
    return control_flow_ops.group(update_1, *update_2[self])

  @abc.abstractmethod
  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state.

    This should only be called by the master.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      next_stamp_token: An int32 scalar tensor containing the stamp token for
          the next iteration.
      class_id: what class id the handler gathers stats for (for tree per class
          strategy). When class_id=-1, the strategy is not tree per class.

    Returns:
      A tuple (are_splits_ready, partition_id, gain, split_info) where
      are_splits_ready is a scalar boolean tensor, partition_id is a rank 1,
      int32 tensor, gain is a rank 1 float32 tensor and split_info is a rank 1
      string tensor containing serialized SplitInfo protos.
    """