1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Strategy to export custom proto formats.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23 24from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 25from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch 26from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2 27from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 28from tensorflow.contrib.learn.python.learn import export_strategy 29from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils 30from tensorflow.python.client import session as tf_session 31from tensorflow.python.framework import ops 32from tensorflow.python.platform import gfile 33from tensorflow.python.saved_model import loader as saved_model_loader 34from tensorflow.python.saved_model import tag_constants 35 36_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" 37 38 39def make_custom_export_strategy(name, 40 convert_fn, 41 feature_columns, 42 export_input_fn): 43 """Makes custom exporter of GTFlow tree format. 44 45 Args: 46 name: A string, for the name of the export strategy. 47 convert_fn: A function that converts the tree proto to desired format and 48 saves it to the desired location. Can be None to skip conversion. 49 feature_columns: A list of feature columns. 50 export_input_fn: A function that takes no arguments and returns an 51 `InputFnOps`. 52 53 Returns: 54 An `ExportStrategy`. 55 """ 56 base_strategy = saved_model_export_utils.make_export_strategy( 57 serving_input_fn=export_input_fn) 58 input_fn = export_input_fn() 59 (sorted_feature_names, dense_floats, sparse_float_indices, _, _, 60 sparse_int_indices, _, _) = gbdt_batch.extract_features( 61 input_fn.features, feature_columns) 62 63 def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): 64 """A wrapper to export to SavedModel, and convert it to other formats.""" 65 result_dir = base_strategy.export(estimator, export_dir, 66 checkpoint_path, 67 eval_result) 68 with ops.Graph().as_default() as graph: 69 with tf_session.Session(graph=graph) as sess: 70 saved_model_loader.load( 71 sess, [tag_constants.SERVING], result_dir) 72 # Note: This is GTFlow internal API and might change. 73 ensemble_model = graph.get_operation_by_name( 74 "ensemble_model/TreeEnsembleSerialize") 75 _, dfec_str = sess.run(ensemble_model.outputs) 76 dtec = tree_config_pb2.DecisionTreeEnsembleConfig() 77 dtec.ParseFromString(dfec_str) 78 # Export the result in the same folder as the saved model. 79 if convert_fn: 80 convert_fn(dtec, sorted_feature_names, 81 len(dense_floats), 82 len(sparse_float_indices), 83 len(sparse_int_indices), result_dir, eval_result) 84 feature_importances = _get_feature_importances( 85 dtec, sorted_feature_names, 86 len(dense_floats), 87 len(sparse_float_indices), len(sparse_int_indices)) 88 sorted_by_importance = sorted( 89 feature_importances.items(), key=lambda x: -x[1]) 90 assets_dir = os.path.join(result_dir, "assets.extra") 91 gfile.MakeDirs(assets_dir) 92 with gfile.GFile(os.path.join(assets_dir, "feature_importances"), 93 "w") as f: 94 f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) 95 return result_dir 96 return export_strategy.ExportStrategy(name, export_fn) 97 98 99def convert_to_universal_format(dtec, sorted_feature_names, 100 num_dense, num_sparse_float, 101 num_sparse_int, 102 feature_name_to_proto=None): 103 """Convert GTFlow trees to universal format.""" 104 del num_sparse_int # unused. 105 model_and_features = generic_tree_model_pb2.ModelAndFeatures() 106 # TODO(jonasz): Feature descriptions should contain information about how each 107 # feature is processed before it's fed to the model (e.g. bucketing 108 # information). As of now, this serves as a list of features the model uses. 109 for feature_name in sorted_feature_names: 110 if not feature_name_to_proto: 111 model_and_features.features[feature_name].SetInParent() 112 else: 113 model_and_features.features[feature_name].CopyFrom( 114 feature_name_to_proto[feature_name]) 115 model = model_and_features.model 116 model.ensemble.summation_combination_technique.SetInParent() 117 for tree_idx in range(len(dtec.trees)): 118 gtflow_tree = dtec.trees[tree_idx] 119 tree_weight = dtec.tree_weights[tree_idx] 120 member = model.ensemble.members.add() 121 member.submodel_id.value = tree_idx 122 tree = member.submodel.decision_tree 123 for node_idx in range(len(gtflow_tree.nodes)): 124 gtflow_node = gtflow_tree.nodes[node_idx] 125 node = tree.nodes.add() 126 node_type = gtflow_node.WhichOneof("node") 127 node.node_id.value = node_idx 128 if node_type == "leaf": 129 leaf = gtflow_node.leaf 130 if leaf.HasField("vector"): 131 for weight in leaf.vector.value: 132 new_value = node.leaf.vector.value.add() 133 new_value.float_value = weight * tree_weight 134 else: 135 for index, weight in zip( 136 leaf.sparse_vector.index, leaf.sparse_vector.value): 137 new_value = node.leaf.sparse_vector.sparse_value[index] 138 new_value.float_value = weight * tree_weight 139 else: 140 node = node.binary_node 141 # Binary nodes here. 142 if node_type == "dense_float_binary_split": 143 split = gtflow_node.dense_float_binary_split 144 feature_id = split.feature_column 145 inequality_test = node.inequality_left_child_test 146 inequality_test.feature_id.id.value = sorted_feature_names[feature_id] 147 inequality_test.type = ( 148 generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) 149 inequality_test.threshold.float_value = split.threshold 150 elif node_type == "sparse_float_binary_split_default_left": 151 split = gtflow_node.sparse_float_binary_split_default_left.split 152 node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) 153 feature_id = split.feature_column + num_dense 154 inequality_test = node.inequality_left_child_test 155 inequality_test.feature_id.id.value = ( 156 _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % 157 (sorted_feature_names[feature_id], split.dimension_id)) 158 model_and_features.features.pop(sorted_feature_names[feature_id]) 159 (model_and_features.features[inequality_test.feature_id.id.value] 160 .SetInParent()) 161 inequality_test.type = ( 162 generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) 163 inequality_test.threshold.float_value = split.threshold 164 elif node_type == "sparse_float_binary_split_default_right": 165 split = gtflow_node.sparse_float_binary_split_default_right.split 166 node.default_direction = ( 167 generic_tree_model_pb2.BinaryNode.RIGHT) 168 # TODO(nponomareva): adjust this id assignement when we allow multi- 169 # column sparse tensors. 170 feature_id = split.feature_column + num_dense 171 inequality_test = node.inequality_left_child_test 172 inequality_test.feature_id.id.value = ( 173 _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % 174 (sorted_feature_names[feature_id], split.dimension_id)) 175 model_and_features.features.pop(sorted_feature_names[feature_id]) 176 (model_and_features.features[inequality_test.feature_id.id.value] 177 .SetInParent()) 178 inequality_test.type = ( 179 generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) 180 inequality_test.threshold.float_value = split.threshold 181 elif node_type == "categorical_id_binary_split": 182 split = gtflow_node.categorical_id_binary_split 183 node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT 184 feature_id = split.feature_column + num_dense + num_sparse_float 185 categorical_test = ( 186 generic_tree_model_extensions_pb2.MatchingValuesTest()) 187 categorical_test.feature_id.id.value = sorted_feature_names[ 188 feature_id] 189 matching_id = categorical_test.value.add() 190 matching_id.int64_value = split.feature_id 191 node.custom_left_child_test.Pack(categorical_test) 192 else: 193 raise ValueError("Unexpected node type %s", node_type) 194 node.left_child_id.value = split.left_id 195 node.right_child_id.value = split.right_id 196 return model_and_features 197 198 199def _get_feature_importances(dtec, feature_names, num_dense_floats, 200 num_sparse_float, num_sparse_int): 201 """Export the feature importance per feature column.""" 202 del num_sparse_int # Unused. 203 sums = collections.defaultdict(lambda: 0) 204 for tree_idx in range(len(dtec.trees)): 205 tree = dtec.trees[tree_idx] 206 for tree_node in tree.nodes: 207 node_type = tree_node.WhichOneof("node") 208 if node_type == "dense_float_binary_split": 209 split = tree_node.dense_float_binary_split 210 split_column = feature_names[split.feature_column] 211 elif node_type == "sparse_float_binary_split_default_left": 212 split = tree_node.sparse_float_binary_split_default_left.split 213 split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( 214 feature_names[split.feature_column + num_dense_floats], 215 split.dimension_id) 216 elif node_type == "sparse_float_binary_split_default_right": 217 split = tree_node.sparse_float_binary_split_default_right.split 218 split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( 219 feature_names[split.feature_column + num_dense_floats], 220 split.dimension_id) 221 elif node_type == "categorical_id_binary_split": 222 split = tree_node.categorical_id_binary_split 223 split_column = feature_names[split.feature_column + num_dense_floats + 224 num_sparse_float] 225 elif node_type == "categorical_id_set_membership_binary_split": 226 split = tree_node.categorical_id_set_membership_binary_split 227 split_column = feature_names[split.feature_column + num_dense_floats + 228 num_sparse_float] 229 elif node_type == "leaf": 230 assert tree_node.node_metadata.gain == 0 231 continue 232 else: 233 raise ValueError("Unexpected split type %s", node_type) 234 # Apply shrinkage factor. It is important since it is not always uniform 235 # across different trees. 236 sums[split_column] += ( 237 tree_node.node_metadata.gain * dtec.tree_weights[tree_idx]) 238 return dict(sums) 239