# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy to export custom proto formats."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants

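# Name template for an individual dimension of a multi-dimensional sparse
# float feature column, e.g. dimension 2 of column "weights" is "weights_2".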
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"


def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn):
  """Makes a custom export strategy for the GTFlow tree format.

  Args:
    name: A string, for the name of the export strategy.
    convert_fn: A function that converts the tree proto to the desired format
      and saves it to the desired location. Can be None to skip conversion.
    feature_columns: A list of feature columns.
    export_input_fn: A function that takes no arguments and returns an
      `InputFnOps`.

  Returns:
    An `ExportStrategy`.
  """
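  # Example usage (a sketch; `_convert_fn` and `serving_input_fn` are
  # hypothetical caller-supplied callables, not part of this module):
  #
  #   def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float,
  #                   num_sparse_int, export_dir, unused_eval_result):
  #     universal = convert_to_universal_format(
  #         dtec, sorted_feature_names, num_dense, num_sparse_float,
  #         num_sparse_int)
  #     ...  # Write `universal` somewhere under `export_dir`.
  #
  #   strategy = make_custom_export_strategy(
  #       "universal_format", _convert_fn, feature_columns, serving_input_fn)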
  base_strategy = saved_model_export_utils.make_export_strategy(
      serving_input_fn=export_input_fn)
  input_fn = export_input_fn()
  (sorted_feature_names, dense_floats, sparse_float_indices, _, _,
   sparse_int_indices, _, _) = gbdt_batch.extract_features(
       input_fn.features, feature_columns)

  def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
    """A wrapper to export to SavedModel, and convert it to other formats."""
    result_dir = base_strategy.export(estimator, export_dir,
                                      checkpoint_path,
                                      eval_result)
    with ops.Graph().as_default() as graph:
      with tf_session.Session(graph=graph) as sess:
        saved_model_loader.load(
            sess, [tag_constants.SERVING], result_dir)
        # Note: This is GTFlow internal API and might change.
        ensemble_model = graph.get_operation_by_name(
            "ensemble_model/TreeEnsembleSerialize")
        _, dfec_str = sess.run(ensemble_model.outputs)
        dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
        dtec.ParseFromString(dfec_str)
        # Export the result in the same folder as the saved model.
        if convert_fn:
          convert_fn(dtec, sorted_feature_names,
                     len(dense_floats),
                     len(sparse_float_indices),
                     len(sparse_int_indices), result_dir, eval_result)
        feature_importances = _get_feature_importances(
            dtec, sorted_feature_names,
            len(dense_floats),
            len(sparse_float_indices), len(sparse_int_indices))
        sorted_by_importance = sorted(
            feature_importances.items(), key=lambda x: -x[1])
        assets_dir = os.path.join(result_dir, "assets.extra")
        gfile.MakeDirs(assets_dir)
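        # Write feature importances to assets.extra as "<name>, <importance>"
        # lines, most important feature first.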
        with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
                         "w") as f:
          f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
    return result_dir
  return export_strategy.ExportStrategy(name, export_fn)


def convert_to_universal_format(dtec, sorted_feature_names,
                                num_dense, num_sparse_float,
                                num_sparse_int,
                                feature_name_to_proto=None):
  """Converts a GTFlow tree ensemble to the universal proto format.

  Returns a `generic_tree_model_pb2.ModelAndFeatures` proto built from the
  given `DecisionTreeEnsembleConfig` proto and the sorted feature names.
  """
  del num_sparse_int  # unused.
  model_and_features = generic_tree_model_pb2.ModelAndFeatures()
  # TODO(jonasz): Feature descriptions should contain information about how each
  # feature is processed before it's fed to the model (e.g. bucketing
  # information). As of now, this serves as a list of features the model uses.
  for feature_name in sorted_feature_names:
    if not feature_name_to_proto:
      model_and_features.features[feature_name].SetInParent()
    else:
      model_and_features.features[feature_name].CopyFrom(
          feature_name_to_proto[feature_name])
  model = model_and_features.model
  model.ensemble.summation_combination_technique.SetInParent()
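  # Each GTFlow tree becomes one ensemble member; its leaf values are scaled
  # by the tree weight so a plain summation combines the members.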
  for tree_idx in range(len(dtec.trees)):
    gtflow_tree = dtec.trees[tree_idx]
    tree_weight = dtec.tree_weights[tree_idx]
    member = model.ensemble.members.add()
    member.submodel_id.value = tree_idx
    tree = member.submodel.decision_tree
    for node_idx in range(len(gtflow_tree.nodes)):
      gtflow_node = gtflow_tree.nodes[node_idx]
      node = tree.nodes.add()
      node_type = gtflow_node.WhichOneof("node")
      node.node_id.value = node_idx
      if node_type == "leaf":
        leaf = gtflow_node.leaf
        if leaf.HasField("vector"):
          for weight in leaf.vector.value:
            new_value = node.leaf.vector.value.add()
            new_value.float_value = weight * tree_weight
        else:
          for index, weight in zip(
              leaf.sparse_vector.index, leaf.sparse_vector.value):
            new_value = node.leaf.sparse_vector.sparse_value[index]
            new_value.float_value = weight * tree_weight
      else:
        node = node.binary_node
        # Binary nodes here.
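        # `sorted_feature_names` lists dense float columns first, then sparse
        # float columns, then sparse int (categorical) columns, so a split's
        # feature_column index is offset by the sizes of the preceding groups.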
        if node_type == "dense_float_binary_split":
          split = gtflow_node.dense_float_binary_split
          feature_id = split.feature_column
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = sorted_feature_names[feature_id]
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "sparse_float_binary_split_default_left":
          split = gtflow_node.sparse_float_binary_split_default_left.split
          node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT)
          feature_id = split.feature_column + num_dense
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = (
              _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
              (sorted_feature_names[feature_id], split.dimension_id))
          model_and_features.features.pop(sorted_feature_names[feature_id])
          (model_and_features.features[inequality_test.feature_id.id.value]
           .SetInParent())
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "sparse_float_binary_split_default_right":
          split = gtflow_node.sparse_float_binary_split_default_right.split
          node.default_direction = (
              generic_tree_model_pb2.BinaryNode.RIGHT)
          # TODO(nponomareva): adjust this id assignment when we allow multi-
          # column sparse tensors.
          feature_id = split.feature_column + num_dense
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = (
              _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
              (sorted_feature_names[feature_id], split.dimension_id))
          model_and_features.features.pop(sorted_feature_names[feature_id])
          (model_and_features.features[inequality_test.feature_id.id.value]
           .SetInParent())
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "categorical_id_binary_split":
          split = gtflow_node.categorical_id_binary_split
          node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT
          feature_id = split.feature_column + num_dense + num_sparse_float
          categorical_test = (
              generic_tree_model_extensions_pb2.MatchingValuesTest())
          categorical_test.feature_id.id.value = sorted_feature_names[
              feature_id]
          matching_id = categorical_test.value.add()
          matching_id.int64_value = split.feature_id
          node.custom_left_child_test.Pack(categorical_test)
        else:
          raise ValueError("Unexpected node type %s" % node_type)
        node.left_child_id.value = split.left_id
        node.right_child_id.value = split.right_id
  return model_and_features


def _get_feature_importances(dtec, feature_names, num_dense_floats,
                             num_sparse_float, num_sparse_int):
  """Computes the feature importance (total split gain) per feature column."""
  del num_sparse_int  # Unused.
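  # Accumulates the total split gain per feature column name.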
  sums = collections.defaultdict(lambda: 0)
  for tree_idx in range(len(dtec.trees)):
    tree = dtec.trees[tree_idx]
    for tree_node in tree.nodes:
      node_type = tree_node.WhichOneof("node")
      if node_type == "dense_float_binary_split":
        split = tree_node.dense_float_binary_split
        split_column = feature_names[split.feature_column]
      elif node_type == "sparse_float_binary_split_default_left":
        split = tree_node.sparse_float_binary_split_default_left.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "sparse_float_binary_split_default_right":
        split = tree_node.sparse_float_binary_split_default_right.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "categorical_id_binary_split":
        split = tree_node.categorical_id_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "categorical_id_set_membership_binary_split":
        split = tree_node.categorical_id_set_membership_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "leaf":
        assert tree_node.node_metadata.gain == 0
        continue
      else:
        raise ValueError("Unexpected split type %s" % node_type)
      # Apply the shrinkage factor. This is important since it is not always
      # uniform across different trees.
      sums[split_column] += (
          tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
  return dict(sums)