1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
3cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFloweryou may not use this file except in compliance with the License.
5cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerYou may obtain a copy of the License at
6cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
7cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
9cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerSee the License for the specific language governing permissions and
13cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerlimitations under the License.
14cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower==============================================================================*/
15cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
16eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#ifndef TENSORFLOW_KERNELS_HINGE_LOSS_H_
17eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#define TENSORFLOW_KERNELS_HINGE_LOSS_H_
18cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
19cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower#include <algorithm>
20a075f1014dc9736d50b79ba0de62112030df2c5fA. Unique TensorFlower#include <limits>
21cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
22eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/loss.h"
239ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h"
24a075f1014dc9736d50b79ba0de62112030df2c5fA. Unique TensorFlower#include "tensorflow/core/lib/core/status.h"
259ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
26cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowernamespace tensorflow {
279ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
289ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerclass HingeLossUpdater : public DualLossUpdater {
299ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public:
30cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // Computes the updated dual variable (corresponding) to a single example. The
31cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // updated dual value maximizes the objective function of the dual
32cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // optimization problem associated with hinge loss (conditioned on keeping the
33cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // rest of the dual variables intact). The method below finds an optimal delta
34cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // (difference between updated and previous dual value) using the update rule
35cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // within SDCA procedure (see http://arxiv.org/pdf/1209.1873v2.pdf, page 5)
36cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // and the particular form of conjugate function for hinge loss.
376e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower  //
38cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower  // The CoCoA+ modification is detailed in readme.md.
39cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower  //
40733321f6159f55474f4d7822b5a1c174d30e696dYuan Yu  // TODO(sibyl-vie3Poto): Write up a doc with concrete derivation and point to it from
41cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // here.
421c137bd29dc53541e3699c46f301797dd750b82aA. Unique TensorFlower  double ComputeUpdatedDual(const int num_loss_partitions, const double label,
436e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower                            const double example_weight,
449ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower                            const double current_dual, const double wx,
45cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower                            const double weighted_example_norm) const final {
46cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // Intutitvely there are 3 cases:
471f6ed06221949c5a9037f2d81f30208d23f3afc6Andreas Solleder    // a. new optimal value of the dual variable falls within the admissible
48cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // range [0, 1]. In this case we set new dual to this value.
49cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // b. new optimal value is < 0. Then, because of convexity, the optimal
50cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // valid value for new dual = 0
51cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // c. new optimal value > 1.0. Then new optimal value should be set to 1.0.
52cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    const double candidate_optimal_dual =
53982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        current_dual + (label - wx) / (num_loss_partitions * example_weight *
54982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                       weighted_example_norm);
55cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    if (label * candidate_optimal_dual < 0) {
56cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower      return 0.0;
57cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    }
58cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    if (label * candidate_optimal_dual > 1.0) {
59cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower      return label;
60cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    }
61cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    return candidate_optimal_dual;
62cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  }
63cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
64cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // Conjugate of hinge loss. This is computed as:
65cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // \phi*(z) = z if z \in [-1, 0] and +infinity everywhere else. See for
66cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // instance http://www.eecs.berkeley.edu/~wainwrig/stat241b/lec10.pdf
67cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // Here we want the weighted version of the conjugate loss. It turns out, that
68cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // if w is the weight of an example, the conjugate of the weighted hinge loss
69cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // is given by:
70cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // \phi*(z) = z if z \in [-w, 0] and +infinity everywhere else. Here the
71cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // conjugate function depends not only on the weight of the example but also
72cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // on its label. In particular:
73cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where
74cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w).
759ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  double ComputeDualLoss(const double current_dual, const double example_label,
769ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower                         const double example_weight) const final {
77cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // For binary classification, there are 2 conjugate functions, one per
78cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    // label value (-1 and 1).
79cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    const double y_alpha = current_dual * example_label;  // y \alpha
80cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    if (y_alpha < 0 || y_alpha > 1.0) {
81cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower      return std::numeric_limits<double>::max();
82cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    }
83cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    return -y_alpha * example_weight;
84cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  }
85cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
86cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // Hinge loss for binary classification for a single example. Hinge loss
87cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss).
88cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  // For weighted instances loss should be multiplied by the instance weight.
899ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  double ComputePrimalLoss(const double wx, const double example_label,
909ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower                           const double example_weight) const final {
91cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    const double y_wx = example_label * wx;
92cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    return std::max(0.0, 1 - y_wx) * example_weight;
93cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower  }
949ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
95f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower  double PrimalLossDerivative(const double wx, const double label,
96f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower                              const double example_weight) const final {
97f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower    if (label * wx < 1) {
98f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower      return -label * example_weight;
99f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower    }
100f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower    return 0;
101f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower  }
102f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower
103f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower  // The smoothness constant is 0 since the derivative of the loss is not
104f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower  // Lipschitz
105f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower  double SmoothnessConstant() const final { return 0; }
106f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower
1079ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively
1089ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  // as expected by hinge loss.
1099ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  Status ConvertLabel(float* const example_label) const final {
1109ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    if (*example_label == 0.0) {
1119ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower      *example_label = -1;
1129ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower      return Status::OK();
1139ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    }
1149ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    if (*example_label == 1.0) {
1159ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower      return Status::OK();
1169ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    }
1179ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    return errors::InvalidArgument(
1189ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower        "Only labels of 0.0 or 1.0 are supported right now. "
1199ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower        "Found example with label: ",
1209ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower        *example_label);
1219ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  }
122cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower};
1239ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
124cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower}  // namespace tensorflow
125cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower
126eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#endif  // TENSORFLOW_KERNELS_HINGE_LOSS_H_
127