1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the GTFlow model ops. 16 17The tests cover: 18- Loading a model from protobufs. 19- Running Predictions using an existing model. 20- Serializing the model. 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import os 28 29import numpy as np 30 31from tensorflow.contrib.boosted_trees.proto import learner_pb2 32from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 33from tensorflow.contrib.boosted_trees.python.ops import model_ops 34from tensorflow.contrib.boosted_trees.python.ops import prediction_ops 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import test_util 37from tensorflow.python.ops import resources 38from tensorflow.python.ops import variables 39from tensorflow.python.platform import googletest 40from tensorflow.python.training import saver 41 42 43def _append_to_leaf(leaf, c_id, w): 44 """Helper method for building tree leaves. 45 46 Appends weight contributions for the given class index to a leaf node. 47 48 Args: 49 leaf: leaf node to append to. 50 c_id: class Id for the weight update. 51 w: weight contribution value. 52 """ 53 leaf.sparse_vector.index.append(c_id) 54 leaf.sparse_vector.value.append(w) 55 56 57def _set_float_split(split, feat_col, thresh, l_id, r_id): 58 """Helper method for building tree float splits. 59 60 Sets split feature column, threshold and children. 61 62 Args: 63 split: split node to update. 64 feat_col: feature column for the split. 65 thresh: threshold to split on forming rule x <= thresh. 66 l_id: left child Id. 67 r_id: right child Id. 68 """ 69 split.feature_column = feat_col 70 split.threshold = thresh 71 split.left_id = l_id 72 split.right_id = r_id 73 74 75class ModelOpsTest(test_util.TensorFlowTestCase): 76 77 def setUp(self): 78 """Sets up test for model_ops. 79 80 Create a batch of two examples having one dense float, two sparse float and 81 one sparse int features. 82 The data looks like the following: 83 | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | 84 | 0 | 7 | -3 | | | 85 | 1 | -2 | | 4 | 9,1 | 86 """ 87 super(ModelOpsTest, self).setUp() 88 self._dense_float_tensor = np.array([[7.0], [-2.0]]) 89 self._sparse_float_indices1 = np.array([[0, 0]]) 90 self._sparse_float_values1 = np.array([-3.0]) 91 self._sparse_float_shape1 = np.array([2, 1]) 92 self._sparse_float_indices2 = np.array([[1, 0]]) 93 self._sparse_float_values2 = np.array([4.0]) 94 self._sparse_float_shape2 = np.array([2, 1]) 95 self._sparse_int_indices1 = np.array([[1, 0], [1, 1]]) 96 self._sparse_int_values1 = np.array([9, 1]) 97 self._sparse_int_shape1 = np.array([2, 2]) 98 self._seed = 123 99 100 def testCreate(self): 101 with self.test_session(): 102 tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() 103 tree = tree_ensemble_config.trees.add() 104 _append_to_leaf(tree.nodes.add().leaf, 0, -0.4) 105 tree_ensemble_config.tree_weights.append(1.0) 106 107 # Prepare learner config. 108 learner_config = learner_pb2.LearnerConfig() 109 learner_config.num_classes = 2 110 111 tree_ensemble_handle = model_ops.tree_ensemble_variable( 112 stamp_token=3, 113 tree_ensemble_config=tree_ensemble_config.SerializeToString(), 114 name="create_tree") 115 resources.initialize_resources(resources.shared_resources()).run() 116 117 result, _ = prediction_ops.gradient_trees_prediction( 118 tree_ensemble_handle, 119 self._seed, [self._dense_float_tensor], [ 120 self._sparse_float_indices1, self._sparse_float_indices2 121 ], [self._sparse_float_values1, self._sparse_float_values2], 122 [self._sparse_float_shape1, 123 self._sparse_float_shape2], [self._sparse_int_indices1], 124 [self._sparse_int_values1], [self._sparse_int_shape1], 125 learner_config=learner_config.SerializeToString(), 126 apply_dropout=False, 127 apply_averaging=False, 128 center_bias=False, 129 reduce_dim=True) 130 self.assertAllClose(result.eval(), [[-0.4], [-0.4]]) 131 stamp_token = model_ops.tree_ensemble_stamp_token(tree_ensemble_handle) 132 self.assertEqual(stamp_token.eval(), 3) 133 134 def testSerialization(self): 135 with ops.Graph().as_default() as graph: 136 with self.test_session(graph): 137 tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() 138 # Bias tree only for second class. 139 tree1 = tree_ensemble_config.trees.add() 140 _append_to_leaf(tree1.nodes.add().leaf, 1, -0.2) 141 142 tree_ensemble_config.tree_weights.append(1.0) 143 144 # Depth 2 tree. 145 tree2 = tree_ensemble_config.trees.add() 146 tree_ensemble_config.tree_weights.append(1.0) 147 _set_float_split(tree2.nodes.add() 148 .sparse_float_binary_split_default_right.split, 1, 4.0, 149 1, 2) 150 _set_float_split(tree2.nodes.add().dense_float_binary_split, 0, 9.0, 3, 151 4) 152 _append_to_leaf(tree2.nodes.add().leaf, 0, 0.5) 153 _append_to_leaf(tree2.nodes.add().leaf, 1, 1.2) 154 _append_to_leaf(tree2.nodes.add().leaf, 0, -0.9) 155 156 tree_ensemble_handle = model_ops.tree_ensemble_variable( 157 stamp_token=7, 158 tree_ensemble_config=tree_ensemble_config.SerializeToString(), 159 name="saver_tree") 160 stamp_token, serialized_config = model_ops.tree_ensemble_serialize( 161 tree_ensemble_handle) 162 resources.initialize_resources(resources.shared_resources()).run() 163 self.assertEqual(stamp_token.eval(), 7) 164 serialized_config = serialized_config.eval() 165 166 with ops.Graph().as_default() as graph: 167 with self.test_session(graph): 168 tree_ensemble_handle2 = model_ops.tree_ensemble_variable( 169 stamp_token=9, 170 tree_ensemble_config=serialized_config, 171 name="saver_tree2") 172 resources.initialize_resources(resources.shared_resources()).run() 173 174 # Prepare learner config. 175 learner_config = learner_pb2.LearnerConfig() 176 learner_config.num_classes = 3 177 178 result, _ = prediction_ops.gradient_trees_prediction( 179 tree_ensemble_handle2, 180 self._seed, [self._dense_float_tensor], [ 181 self._sparse_float_indices1, self._sparse_float_indices2 182 ], [self._sparse_float_values1, self._sparse_float_values2], 183 [self._sparse_float_shape1, 184 self._sparse_float_shape2], [self._sparse_int_indices1], 185 [self._sparse_int_values1], [self._sparse_int_shape1], 186 learner_config=learner_config.SerializeToString(), 187 apply_dropout=False, 188 apply_averaging=False, 189 center_bias=False, 190 reduce_dim=True) 191 192 # Re-serialize tree. 193 stamp_token2, serialized_config2 = model_ops.tree_ensemble_serialize( 194 tree_ensemble_handle2) 195 196 # The first example will get bias class 1 -0.2 from first tree and 197 # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], 198 # the second example will get the same bias class 1 -0.2 and leaf 3 199 # payload of class 1 1.2 hence [0.0, 1.0]. 200 self.assertEqual(stamp_token2.eval(), 9) 201 202 # Class 2 does have scores in the leaf => it gets score 0. 203 self.assertEqual(serialized_config2.eval(), serialized_config) 204 self.assertAllClose(result.eval(), [[0.5, -0.2], [0, 1.0]]) 205 206 def testRestore(self): 207 # Calling self.test_session() without a graph specified results in 208 # TensorFlowTestCase caching the session and returning the same one 209 # every time. In this test, we need to create two different sessions 210 # which is why we also create a graph and pass it to self.test_session() 211 # to ensure no caching occurs under the hood. 212 save_path = os.path.join(self.get_temp_dir(), "restore-test") 213 with ops.Graph().as_default() as graph: 214 with self.test_session(graph) as sess: 215 # Prepare learner config. 216 learner_config = learner_pb2.LearnerConfig() 217 learner_config.num_classes = 2 218 219 # Add the first tree and save. 220 tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() 221 tree = tree_ensemble_config.trees.add() 222 tree_ensemble_config.tree_metadata.add().is_finalized = True 223 tree_ensemble_config.tree_weights.append(1.0) 224 _append_to_leaf(tree.nodes.add().leaf, 0, -0.1) 225 tree_ensemble_handle = model_ops.tree_ensemble_variable( 226 stamp_token=3, 227 tree_ensemble_config=tree_ensemble_config.SerializeToString(), 228 name="restore_tree") 229 resources.initialize_resources(resources.shared_resources()).run() 230 variables.initialize_all_variables().run() 231 my_saver = saver.Saver() 232 233 # Add the second tree and replace the ensemble of the handle. 234 tree2 = tree_ensemble_config.trees.add() 235 tree_ensemble_config.tree_weights.append(1.0) 236 _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0) 237 # Predict to confirm. 238 with ops.control_dependencies([ 239 model_ops.tree_ensemble_deserialize( 240 tree_ensemble_handle, 241 stamp_token=3, 242 tree_ensemble_config=tree_ensemble_config.SerializeToString()) 243 ]): 244 result, _ = prediction_ops.gradient_trees_prediction( 245 tree_ensemble_handle, 246 self._seed, [self._dense_float_tensor], [ 247 self._sparse_float_indices1, self._sparse_float_indices2 248 ], [self._sparse_float_values1, self._sparse_float_values2], 249 [self._sparse_float_shape1, 250 self._sparse_float_shape2], [self._sparse_int_indices1], 251 [self._sparse_int_values1], [self._sparse_int_shape1], 252 learner_config=learner_config.SerializeToString(), 253 apply_dropout=False, 254 apply_averaging=False, 255 center_bias=False, 256 reduce_dim=True) 257 self.assertAllClose([[-1.1], [-1.1]], result.eval()) 258 # Save before adding other trees. 259 val = my_saver.save(sess, save_path) 260 self.assertEqual(save_path, val) 261 262 # Add more trees after saving. 263 tree3 = tree_ensemble_config.trees.add() 264 tree_ensemble_config.tree_weights.append(1.0) 265 _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0) 266 # Predict to confirm. 267 with ops.control_dependencies([ 268 model_ops.tree_ensemble_deserialize( 269 tree_ensemble_handle, 270 stamp_token=3, 271 tree_ensemble_config=tree_ensemble_config.SerializeToString()) 272 ]): 273 result, _ = prediction_ops.gradient_trees_prediction( 274 tree_ensemble_handle, 275 self._seed, [self._dense_float_tensor], [ 276 self._sparse_float_indices1, self._sparse_float_indices2 277 ], [self._sparse_float_values1, self._sparse_float_values2], 278 [self._sparse_float_shape1, 279 self._sparse_float_shape2], [self._sparse_int_indices1], 280 [self._sparse_int_values1], [self._sparse_int_shape1], 281 learner_config=learner_config.SerializeToString(), 282 apply_dropout=False, 283 apply_averaging=False, 284 center_bias=False, 285 reduce_dim=True) 286 self.assertAllClose(result.eval(), [[-11.1], [-11.1]]) 287 288 # Start a second session. In that session the parameter nodes 289 # have not been initialized either. 290 with ops.Graph().as_default() as graph: 291 with self.test_session(graph) as sess: 292 tree_ensemble_handle = model_ops.tree_ensemble_variable( 293 stamp_token=0, tree_ensemble_config="", name="restore_tree") 294 my_saver = saver.Saver() 295 my_saver.restore(sess, save_path) 296 result, _ = prediction_ops.gradient_trees_prediction( 297 tree_ensemble_handle, 298 self._seed, [self._dense_float_tensor], [ 299 self._sparse_float_indices1, self._sparse_float_indices2 300 ], [self._sparse_float_values1, self._sparse_float_values2], 301 [self._sparse_float_shape1, 302 self._sparse_float_shape2], [self._sparse_int_indices1], 303 [self._sparse_int_values1], [self._sparse_int_shape1], 304 learner_config=learner_config.SerializeToString(), 305 apply_dropout=False, 306 apply_averaging=False, 307 center_bias=False, 308 reduce_dim=True) 309 # Make sure we only have the first and second tree. 310 # The third tree was added after the save. 311 self.assertAllClose(result.eval(), [[-1.1], [-1.1]]) 312 313 314if __name__ == "__main__": 315 googletest.main() 316