1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ftrl-proximal for TensorFlow."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.training import optimizer
24from tensorflow.python.training import training_ops
25from tensorflow.python.util.tf_export import tf_export
26
27
@tf_export("train.FtrlOptimizer")
class FtrlOptimizer(optimizer.Optimizer):
  """Optimizer that implements the FTRL algorithm.

  See this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version has support for both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
  loss function).

  Two slots are kept per variable: "accum" (initialized to
  `initial_accumulator_value`) and "linear" (initialized to zeros).
  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_locking=False,
               name="Ftrl",
               accum_name=None,
               linear_name=None,
               l2_shrinkage_regularization_strength=0.0):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
      initial_accumulator_value: The starting value for accumulators.
        Only positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Ftrl".
      accum_name: The suffix for the variable that keeps the gradient squared
        accumulator.  If not present, defaults to name.
      linear_name: The suffix for the variable that keeps the linear gradient
        accumulator.  If not present, defaults to name + "_1".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
        The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When input is sparse shrinkage will only happen on the active weights.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    # Validate eagerly so that a bad configuration fails at construction time
    # rather than when the update ops are built.
    if initial_accumulator_value <= 0.0:
      raise ValueError("initial_accumulator_value %f needs to be positive" %
                       initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    # Tensor versions of the hyperparameters; populated by _prepare().
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    """Creates the "accum" and "linear" slots for every variable in `var_list`.

    "accum" starts at `initial_accumulator_value` and "linear" starts at zero.
    Both are created under a colocation scope of the variable they track.

    Args:
      var_list: Iterable of variables to create slots for.
    """
    for v in var_list:
      with ops.colocate_with(v):
        val = constant_op.constant(
            self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
        self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
    """Converts the Python hyperparameters to tensors for the update ops."""
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength, name="l2_regularization_strength")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _cast_ftrl_hyperparameters(self, dtype, include_l2_shrinkage):
    """Casts the prepared hyperparameter tensors to a single `dtype`.

    Centralizing the casts guarantees that every hyperparameter handed to one
    update op shares one dtype, and that the casts are always created in the
    same order.

    Args:
      dtype: The dtype every hyperparameter is cast to.
      include_l2_shrinkage: If `True`, the L2 shrinkage strength is included
        (as needed by the `*_ftrl_v2` update functions).

    Returns:
      A list of tensors ordered as the update functions are called in this
      class: learning rate, L1, L2, [L2 shrinkage,] learning rate power.
    """
    tensors = [
        self._learning_rate_tensor,
        self._l1_regularization_strength_tensor,
        self._l2_regularization_strength_tensor,
    ]
    if include_l2_shrinkage:
      tensors.append(self._l2_shrinkage_regularization_strength_tensor)
    tensors.append(self._learning_rate_power_tensor)
    return [math_ops.cast(t, dtype) for t in tensors]

  def _apply_dense(self, grad, var):
    """Builds the dense FTRL update op for `var`.

    Args:
      grad: Gradient for `var`.
      var: The variable to update.

    Returns:
      The result of `training_ops.apply_ftrl`, or of
      `training_ops.apply_ftrl_v2` when L2 shrinkage is enabled.
    """
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    dtype = var.dtype.base_dtype
    # The v2 op is only needed when a shrinkage penalty was requested.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=False),
          use_locking=self._use_locking)
    return training_ops.apply_ftrl_v2(
        var,
        accum,
        linear,
        grad,
        *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=True),
        use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var):
    """Builds the dense FTRL update op for a variable accessed via `.handle`.

    Args:
      grad: Gradient for `var`.
      var: The variable to update; it and its slots are passed by `.handle`.

    Returns:
      The result of `training_ops.resource_apply_ftrl`, or of
      `training_ops.resource_apply_ftrl_v2` when L2 shrinkage is enabled.
    """
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    dtype = var.dtype.base_dtype
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=False),
          use_locking=self._use_locking)
    return training_ops.resource_apply_ftrl_v2(
        var.handle,
        accum.handle,
        linear.handle,
        grad,
        *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=True),
        use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    """Builds the sparse FTRL update op for `var`.

    Args:
      grad: Sparse gradient exposing `.values` and `.indices`.
      var: The variable to update.

    Returns:
      The result of `training_ops.sparse_apply_ftrl`, or of
      `training_ops.sparse_apply_ftrl_v2` when L2 shrinkage is enabled.
    """
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    # All hyperparameters, including the shrinkage strength, are cast to the
    # variable's dtype so this path agrees with the dense one.
    dtype = var.dtype.base_dtype
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=False),
          use_locking=self._use_locking)
    return training_ops.sparse_apply_ftrl_v2(
        var,
        accum,
        linear,
        grad.values,
        grad.indices,
        *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=True),
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    """Builds the sparse FTRL update op for a variable accessed via `.handle`.

    The hyperparameters are cast to `grad.dtype` here, since the variable and
    its slots are passed to the op as handles.

    Args:
      grad: Gradient values for the rows selected by `indices`.
      var: The variable to update; it and its slots are passed by `.handle`.
      indices: Indices into `var` that `grad` applies to.

    Returns:
      The result of `training_ops.resource_sparse_apply_ftrl`, or of
      `training_ops.resource_sparse_apply_ftrl_v2` when L2 shrinkage is
      enabled.
    """
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    dtype = grad.dtype
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=False),
          use_locking=self._use_locking)
    return training_ops.resource_sparse_apply_ftrl_v2(
        var.handle,
        accum.handle,
        linear.handle,
        grad,
        indices,
        *self._cast_ftrl_hyperparameters(dtype, include_l2_shrinkage=True),
        use_locking=self._use_locking)
270