1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the GTFlow model ops.
16
17The tests cover:
18- Loading a model from protobufs.
19- Running Predictions using an existing model.
20- Serializing the model.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import os
28
29import numpy as np
30
31from tensorflow.contrib.boosted_trees.proto import learner_pb2
32from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
33from tensorflow.contrib.boosted_trees.python.ops import model_ops
34from tensorflow.contrib.boosted_trees.python.ops import prediction_ops
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import test_util
37from tensorflow.python.ops import resources
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import googletest
40from tensorflow.python.training import saver
41
42
def _append_to_leaf(leaf, c_id, w):
  """Adds one (class index, weight) entry to a leaf's sparse vector.

  The leaf is modified in place: `c_id` is appended to
  `leaf.sparse_vector.index` and `w` to `leaf.sparse_vector.value`, so the
  two repeated fields stay aligned entry by entry.

  Args:
    leaf: leaf node to append to.
    c_id: class Id the weight applies to.
    w: weight contribution value.
  """
  vector = leaf.sparse_vector
  vector.index.append(c_id)
  vector.value.append(w)
55
56
def _set_float_split(split, feat_col, thresh, l_id, r_id):
  """Fills in a float split node in place.

  Writes the feature column, the threshold and both child ids onto `split`.

  Args:
    split: split node to update.
    feat_col: feature column for the split.
    thresh: threshold to split on forming rule x <= thresh.
    l_id: left child Id.
    r_id: right child Id.
  """
  field_values = (
      ("feature_column", feat_col),
      ("threshold", thresh),
      ("left_id", l_id),
      ("right_id", r_id),
  )
  for field_name, field_value in field_values:
    setattr(split, field_name, field_value)
73
74
class ModelOpsTest(test_util.TensorFlowTestCase):
  """Tests creating, serializing and restoring tree ensemble resources."""

  def setUp(self):
    """Sets up test for model_ops.

    Create a batch of two examples having one dense float, two sparse float and
    one sparse int features.
    The data looks like the following:
    | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 |
    | 0        |  7     |    -3    |          |          |
    | 1        | -2     |          | 4        |   9,1    |
    """
    super(ModelOpsTest, self).setUp()
    self._dense_float_tensor = np.array([[7.0], [-2.0]])
    self._sparse_float_indices1 = np.array([[0, 0]])
    self._sparse_float_values1 = np.array([-3.0])
    self._sparse_float_shape1 = np.array([2, 1])
    self._sparse_float_indices2 = np.array([[1, 0]])
    self._sparse_float_values2 = np.array([4.0])
    self._sparse_float_shape2 = np.array([2, 1])
    self._sparse_int_indices1 = np.array([[1, 0], [1, 1]])
    self._sparse_int_values1 = np.array([9, 1])
    self._sparse_int_shape1 = np.array([2, 2])
    self._seed = 123

  def _build_prediction(self, tree_ensemble_handle, learner_config):
    """Builds the prediction op for the batch created in setUp.

    The op is added to the current default graph, so any active
    `ops.control_dependencies` scope of the caller applies to it.

    Args:
      tree_ensemble_handle: handle of the tree ensemble to predict with.
      learner_config: a `learner_pb2.LearnerConfig`; it is serialized before
        being handed to the prediction op.

    Returns:
      The first output of `prediction_ops.gradient_trees_prediction`, i.e.
      the predictions tensor; the second output is discarded.
    """
    predictions, _ = prediction_ops.gradient_trees_prediction(
        tree_ensemble_handle,
        self._seed,
        [self._dense_float_tensor],
        [self._sparse_float_indices1, self._sparse_float_indices2],
        [self._sparse_float_values1, self._sparse_float_values2],
        [self._sparse_float_shape1, self._sparse_float_shape2],
        [self._sparse_int_indices1],
        [self._sparse_int_values1],
        [self._sparse_int_shape1],
        learner_config=learner_config.SerializeToString(),
        apply_dropout=False,
        apply_averaging=False,
        center_bias=False,
        reduce_dim=True)
    return predictions

  def testCreate(self):
    """Creates an ensemble from a serialized config and predicts with it."""
    with self.test_session():
      # Single tree made of one leaf carrying -0.4 for class 0.
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree = tree_ensemble_config.trees.add()
      _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
      tree_ensemble_config.tree_weights.append(1.0)

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=3,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="create_tree")
      resources.initialize_resources(resources.shared_resources()).run()

      result = self._build_prediction(tree_ensemble_handle, learner_config)
      self.assertAllClose(result.eval(), [[-0.4], [-0.4]])
      stamp_token = model_ops.tree_ensemble_stamp_token(tree_ensemble_handle)
      self.assertEqual(stamp_token.eval(), 3)

  def testSerialization(self):
    """Serializes an ensemble, reloads it in a new graph and re-serializes."""
    with ops.Graph().as_default() as graph:
      with self.test_session(graph):
        tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
        # Bias tree only for second class.
        tree1 = tree_ensemble_config.trees.add()
        _append_to_leaf(tree1.nodes.add().leaf, 1, -0.2)

        tree_ensemble_config.tree_weights.append(1.0)

        # Depth 2 tree.
        # Node 0 splits on sparse float column 1 (children 1 and 2), node 1
        # splits on dense float column 0 (children 3 and 4), nodes 2-4 are
        # leaves.
        tree2 = tree_ensemble_config.trees.add()
        tree_ensemble_config.tree_weights.append(1.0)
        _set_float_split(tree2.nodes.add()
                         .sparse_float_binary_split_default_right.split, 1, 4.0,
                         1, 2)
        _set_float_split(tree2.nodes.add().dense_float_binary_split, 0, 9.0, 3,
                         4)
        _append_to_leaf(tree2.nodes.add().leaf, 0, 0.5)
        _append_to_leaf(tree2.nodes.add().leaf, 1, 1.2)
        _append_to_leaf(tree2.nodes.add().leaf, 0, -0.9)

        tree_ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=7,
            tree_ensemble_config=tree_ensemble_config.SerializeToString(),
            name="saver_tree")
        stamp_token, serialized_config = model_ops.tree_ensemble_serialize(
            tree_ensemble_handle)
        resources.initialize_resources(resources.shared_resources()).run()
        self.assertEqual(stamp_token.eval(), 7)
        # From here on `serialized_config` holds the evaluated value rather
        # than the tensor, so it can be used in the second graph below.
        serialized_config = serialized_config.eval()

    with ops.Graph().as_default() as graph:
      with self.test_session(graph):
        # Load the serialized ensemble under a different stamp token.
        tree_ensemble_handle2 = model_ops.tree_ensemble_variable(
            stamp_token=9,
            tree_ensemble_config=serialized_config,
            name="saver_tree2")
        resources.initialize_resources(resources.shared_resources()).run()

        # Prepare learner config.
        learner_config = learner_pb2.LearnerConfig()
        learner_config.num_classes = 3

        result = self._build_prediction(tree_ensemble_handle2, learner_config)

        # Re-serialize tree.
        stamp_token2, serialized_config2 = model_ops.tree_ensemble_serialize(
            tree_ensemble_handle2)

        # The reloaded ensemble carries the stamp token it was created with
        # and serializes back to the very same config.
        self.assertEqual(stamp_token2.eval(), 9)
        self.assertEqual(serialized_config2.eval(), serialized_config)

        # The first example will get bias class 1 -0.2 from first tree and
        # node 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2],
        # the second example will get the same bias class 1 -0.2 and node 3
        # payload of class 1 1.2 hence [0.0, 1.0].
        # No leaf carries a weight for class 2; only the two columns for
        # classes 0 and 1 are checked here.
        self.assertAllClose(result.eval(), [[0.5, -0.2], [0, 1.0]])

  def testRestore(self):
    """Saves an ensemble with a Saver and restores it in a fresh session."""
    # Calling self.test_session() without a graph specified results in
    # TensorFlowTestCase caching the session and returning the same one
    # every time. In this test, we need to create two different sessions
    # which is why we also create a graph and pass it to self.test_session()
    # to ensure no caching occurs under the hood.
    save_path = os.path.join(self.get_temp_dir(), "restore-test")

    # Prepare learner config. It is a plain proto used by both sessions, so
    # it is built once up front instead of inside the first graph's scope.
    learner_config = learner_pb2.LearnerConfig()
    learner_config.num_classes = 2

    with ops.Graph().as_default() as graph:
      with self.test_session(graph) as sess:
        # Add the first tree and save.
        tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
        tree = tree_ensemble_config.trees.add()
        tree_ensemble_config.tree_metadata.add().is_finalized = True
        tree_ensemble_config.tree_weights.append(1.0)
        _append_to_leaf(tree.nodes.add().leaf, 0, -0.1)
        tree_ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=3,
            tree_ensemble_config=tree_ensemble_config.SerializeToString(),
            name="restore_tree")
        resources.initialize_resources(resources.shared_resources()).run()
        # global_variables_initializer() is the supported replacement for
        # the deprecated initialize_all_variables().
        variables.global_variables_initializer().run()
        my_saver = saver.Saver()

        # Add the second tree and replace the ensemble of the handle.
        tree2 = tree_ensemble_config.trees.add()
        tree_ensemble_config.tree_weights.append(1.0)
        _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0)
        # Predict to confirm. The control dependency makes the prediction
        # run only after the two-tree config has been deserialized into the
        # handle.
        with ops.control_dependencies([
            model_ops.tree_ensemble_deserialize(
                tree_ensemble_handle,
                stamp_token=3,
                tree_ensemble_config=tree_ensemble_config.SerializeToString())
        ]):
          result = self._build_prediction(tree_ensemble_handle, learner_config)
        self.assertAllClose([[-1.1], [-1.1]], result.eval())
        # Save before adding other trees.
        val = my_saver.save(sess, save_path)
        self.assertEqual(save_path, val)

        # Add more trees after saving.
        tree3 = tree_ensemble_config.trees.add()
        tree_ensemble_config.tree_weights.append(1.0)
        _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0)
        # Predict to confirm.
        with ops.control_dependencies([
            model_ops.tree_ensemble_deserialize(
                tree_ensemble_handle,
                stamp_token=3,
                tree_ensemble_config=tree_ensemble_config.SerializeToString())
        ]):
          result = self._build_prediction(tree_ensemble_handle, learner_config)
        self.assertAllClose(result.eval(), [[-11.1], [-11.1]])

    # Start a second session.  In that session the parameter nodes
    # have not been initialized either.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph) as sess:
        # Same resource name as in the first graph, but created empty; its
        # contents are expected to come from the checkpoint.
        tree_ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0, tree_ensemble_config="", name="restore_tree")
        my_saver = saver.Saver()
        my_saver.restore(sess, save_path)
        result = self._build_prediction(tree_ensemble_handle, learner_config)
        # Make sure we only have the first and second tree.
        # The third tree was added after the save.
        self.assertAllClose(result.eval(), [[-1.1], [-1.1]])
312
313
# Hand control to the googletest runner when this file is executed directly.
if __name__ == "__main__":
  googletest.main()
316