1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Model ops python wrappers."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20# pylint: disable=unused-import
21from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
22# pylint: enable=unused-import
23from tensorflow.contrib.boosted_trees.python.ops import gen_model_ops
24from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_deserialize
25from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_serialize
26# pylint: disable=unused-import
27from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_stamp_token
28# pylint: enable=unused-import
29
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import resources
32from tensorflow.python.training import saver
33
# Register the tree ensemble ops as having no gradient.
for _op_type in ("TreeEnsembleVariable",
                 "TreeEnsembleSerialize",
                 "TreeEnsembleDeserialize"):
  ops.NotDifferentiable(_op_type)
# Keep the module namespace identical to registering each op individually.
del _op_type
37
38
class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints a TreeEnsembleVariable.

  The ensemble is saved as the two tensors produced by
  `tree_ensemble_serialize`:
  - its stamp token, under `<name>_stamp`;
  - its serialized config, under `<name>_config`.

  Restoring passes both back through `tree_ensemble_deserialize`.
  """

  def __init__(self, tree_ensemble_handle, create_op, name):
    """Builds the save specs for the ensemble behind `tree_ensemble_handle`.

    Args:
      tree_ensemble_handle: handle to the tree ensemble variable.
      create_op: the op to initialize the variable; restore ops are given a
        control dependency on it.
      name: the name to save the tree ensemble variable under. It is used as
        the prefix of both saved tensor names.
    """
    stamp_token, ensemble_config = tree_ensemble_serialize(tree_ensemble_handle)
    # A slice spec only matters when saving part of a variable. It is not
    # meaningful for a tree ensemble, so both tensors use an empty spec.
    no_slice = ""
    specs = [
        saver.BaseSaverBuilder.SaveSpec(tensor, no_slice, name + suffix)
        for tensor, suffix in ((stamp_token, "_stamp"),
                               (ensemble_config, "_config"))
    ]
    super(TreeEnsembleVariableSavable, self).__init__(
        tree_ensemble_handle, specs, name)
    self._tree_ensemble_handle = tree_ensemble_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Writes checkpointed state back into the tree ensemble.

    Args:
      restored_tensors: the tensors loaded from a checkpoint, in the same order
        as the save specs: stamp token first, serialized config second.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees, so it is ignored.

    Returns:
      The `tree_ensemble_deserialize` op that restores the ensemble's state.
    """
    # Everything below is created under a control dependency on the create
    # op, so the variable's initializer runs before the restored state is
    # written into it.
    with ops.control_dependencies([self._create_op]):
      stamp_token = restored_tensors[0]
      ensemble_config = restored_tensors[1]
      return tree_ensemble_deserialize(
          self._tree_ensemble_handle,
          stamp_token=stamp_token,
          tree_ensemble_config=ensemble_config)
82
83
def tree_ensemble_variable(stamp_token,
                           tree_ensemble_config,
                           name,
                           container=None):
  """Creates a tree ensemble model and returns a handle to it.

  This also registers the ensemble in two places:
  - the `SAVEABLE_OBJECTS` collection, via a `TreeEnsembleVariableSavable`
    named after the handle tensor, so Saver can checkpoint it;
  - `resources.register_resource`, with its create and is-initialized ops.

  Args:
    stamp_token: The initial stamp token value for the ensemble resource.
    tree_ensemble_config: A `Tensor` of type `string` holding the serialized
      proto of the tree ensemble.
    name: A name for the ensemble variable. The resolved name scope is used as
      both the handle op's name and the resource's `shared_name`.
    container: An optional `string` container for the resource. The default
      `None` is forwarded unchanged to the handle op (presumably treated as
      `""` by the generated wrapper -- confirm).

  Returns:
    The handle tensor produced by `decision_tree_ensemble_resource_handle_op`
    for the tree ensemble resource.
  """
  with ops.name_scope(name, "TreeEnsembleVariable") as scope:
    handle = gen_model_ops.decision_tree_ensemble_resource_handle_op(
        container, shared_name=scope, name=scope)
    init_op = gen_model_ops.create_tree_ensemble_variable(
        handle, stamp_token, tree_ensemble_config)
    is_init_op = gen_model_ops.tree_ensemble_is_initialized_op(handle)

    # Expose the ensemble to Saver, keyed by the handle tensor's name.
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        TreeEnsembleVariableSavable(handle, init_op, handle.name))
    resources.register_resource(handle, init_op, is_init_op)
    return handle
113