# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model ops python wrappers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
# pylint: enable=unused-import
from tensorflow.contrib.boosted_trees.python.ops import gen_model_ops
from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_deserialize
from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_serialize
# pylint: disable=unused-import
from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_stamp_token
# pylint: enable=unused-import

from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.training import saver

# These ops have no gradient; register them as non-differentiable.
ops.NotDifferentiable("TreeEnsembleVariable")
ops.NotDifferentiable("TreeEnsembleSerialize")
ops.NotDifferentiable("TreeEnsembleDeserialize")


class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for TreeEnsembleVariable.

  The ensemble is saved as the two tensors returned by
  `tree_ensemble_serialize`:
    * the stamp token, under the name `<name>_stamp`;
    * the serialized ensemble config, under the name `<name>_config`.
  """

  def __init__(self, tree_ensemble_handle, create_op, name):
    """Creates a TreeEnsembleVariableSavable object.

    Args:
      tree_ensemble_handle: handle to the tree ensemble variable.
      create_op: the op to initialize the variable. `restore` runs under a
        control dependency on this op.
      name: the name to save the tree ensemble variable under. Used as the
        prefix for the `_stamp` and `_config` save spec names.
    """
    stamp_token, ensemble_config = tree_ensemble_serialize(tree_ensemble_handle)
    # slice_spec is useful for saving a slice from a variable. It is not
    # meaningful for the tree ensemble variable, so we just pass an empty
    # value.
    slice_spec = ""
    # NOTE: restore() depends on this order: stamp token first, config second.
    specs = [
        saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
                                        name + "_stamp"),
        saver.BaseSaverBuilder.SaveSpec(ensemble_config, slice_spec,
                                        name + "_config"),
    ]
    super(TreeEnsembleVariableSavable,
          self).__init__(tree_ensemble_handle, specs, name)
    self._tree_ensemble_handle = tree_ensemble_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint, in the
        same order as the save specs built in `__init__`: `[stamp_token,
        ensemble_config]`.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    # Sequence the deserialize after the create op (presumably so the resource
    # exists before its contents are replaced).
    with ops.control_dependencies([self._create_op]):
      return tree_ensemble_deserialize(
          self._tree_ensemble_handle,
          stamp_token=restored_tensors[0],
          tree_ensemble_config=restored_tensors[1])


def tree_ensemble_variable(stamp_token,
                           tree_ensemble_config,
                           name,
                           container=None):
  r"""Creates a tree ensemble model and returns a handle to it.

  Besides the handle, this builds the op that creates the ensemble from
  `stamp_token`/`tree_ensemble_config` and the op that reports whether it is
  initialized. It registers all three via `resources.register_resource`. It
  also adds a `TreeEnsembleVariableSavable` for the handle to the
  `ops.GraphKeys.SAVEABLE_OBJECTS` collection.

  Args:
    stamp_token: The initial stamp token value for the ensemble resource.
    tree_ensemble_config: A `Tensor` of type `string`.
      Serialized proto of the tree ensemble.
    name: A name for the ensemble variable. It is opened as a name scope, and
      the resulting scope string is passed as both the `shared_name` and the
      `name` of the resource handle op.
    container: An optional `string` naming the resource container. It is
      passed through to `decision_tree_ensemble_resource_handle_op`.
      Defaults to `None`.

  Returns:
    A `Tensor`: the resource handle produced by
    `gen_model_ops.decision_tree_ensemble_resource_handle_op`, which refers to
    the tree ensemble.
  """
  with ops.name_scope(name, "TreeEnsembleVariable") as name:
    resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op(
        container, shared_name=name, name=name)
    create_op = gen_model_ops.create_tree_ensemble_variable(
        resource_handle, stamp_token, tree_ensemble_config)
    is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op(
        resource_handle)
    # Adds the variable to the savable list.
    saveable = TreeEnsembleVariableSavable(resource_handle, create_op,
                                           resource_handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    resources.register_resource(resource_handle, create_op, is_initialized_op)
    return resource_handle