1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 3cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFloweryou may not use this file except in compliance with the License. 5cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerYou may obtain a copy of the License at 6cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 7cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 9cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerSee the License for the specific language governing permissions and 13cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowerlimitations under the License. 14cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower==============================================================================*/ 15cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 16eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#ifndef TENSORFLOW_KERNELS_HINGE_LOSS_H_ 17eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#define TENSORFLOW_KERNELS_HINGE_LOSS_H_ 18cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 19cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower#include <algorithm> 20a075f1014dc9736d50b79ba0de62112030df2c5fA. Unique TensorFlower#include <limits> 21cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 22eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/loss.h" 239ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h" 24a075f1014dc9736d50b79ba0de62112030df2c5fA. Unique TensorFlower#include "tensorflow/core/lib/core/status.h" 259ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 26cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlowernamespace tensorflow { 279ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 289ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerclass HingeLossUpdater : public DualLossUpdater { 299ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public: 30cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // Computes the updated dual variable (corresponding) to a single example. The 31cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // updated dual value maximizes the objective function of the dual 32cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // optimization problem associated with hinge loss (conditioned on keeping the 33cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // rest of the dual variables intact). The method below finds an optimal delta 34cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // (difference between updated and previous dual value) using the update rule 35cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // within SDCA procedure (see http://arxiv.org/pdf/1209.1873v2.pdf, page 5) 36cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // and the particular form of conjugate function for hinge loss. 376e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower // 38cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower // The CoCoA+ modification is detailed in readme.md. 39cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower // 40733321f6159f55474f4d7822b5a1c174d30e696dYuan Yu // TODO(sibyl-vie3Poto): Write up a doc with concrete derivation and point to it from 41cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // here. 421c137bd29dc53541e3699c46f301797dd750b82aA. Unique TensorFlower double ComputeUpdatedDual(const int num_loss_partitions, const double label, 436e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower const double example_weight, 449ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower const double current_dual, const double wx, 45cf6857d4766aaabd16c5009156c8e69846c17bceA. Unique TensorFlower const double weighted_example_norm) const final { 46cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // Intutitvely there are 3 cases: 471f6ed06221949c5a9037f2d81f30208d23f3afc6Andreas Solleder // a. new optimal value of the dual variable falls within the admissible 48cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // range [0, 1]. In this case we set new dual to this value. 49cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // b. new optimal value is < 0. Then, because of convexity, the optimal 50cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // valid value for new dual = 0 51cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // c. new optimal value > 1.0. Then new optimal value should be set to 1.0. 52cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower const double candidate_optimal_dual = 53982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen current_dual + (label - wx) / (num_loss_partitions * example_weight * 54982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen weighted_example_norm); 55cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower if (label * candidate_optimal_dual < 0) { 56cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return 0.0; 57cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 58cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower if (label * candidate_optimal_dual > 1.0) { 59cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return label; 60cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 61cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return candidate_optimal_dual; 62cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 63cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 64cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // Conjugate of hinge loss. This is computed as: 65cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // \phi*(z) = z if z \in [-1, 0] and +infinity everywhere else. See for 66cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // instance http://www.eecs.berkeley.edu/~wainwrig/stat241b/lec10.pdf 67cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // Here we want the weighted version of the conjugate loss. It turns out, that 68cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // if w is the weight of an example, the conjugate of the weighted hinge loss 69cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // is given by: 70cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // \phi*(z) = z if z \in [-w, 0] and +infinity everywhere else. Here the 71cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // conjugate function depends not only on the weight of the example but also 72cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // on its label. In particular: 73cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where 74cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w). 759ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower double ComputeDualLoss(const double current_dual, const double example_label, 769ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower const double example_weight) const final { 77cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // For binary classification, there are 2 conjugate functions, one per 78cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // label value (-1 and 1). 79cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower const double y_alpha = current_dual * example_label; // y \alpha 80cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower if (y_alpha < 0 || y_alpha > 1.0) { 81cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return std::numeric_limits<double>::max(); 82cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 83cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return -y_alpha * example_weight; 84cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 85cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 86cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // Hinge loss for binary classification for a single example. Hinge loss 87cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss). 88cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower // For weighted instances loss should be multiplied by the instance weight. 899ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower double ComputePrimalLoss(const double wx, const double example_label, 909ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower const double example_weight) const final { 91cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower const double y_wx = example_label * wx; 92cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower return std::max(0.0, 1 - y_wx) * example_weight; 93cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } 949ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 95f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower double PrimalLossDerivative(const double wx, const double label, 96f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower const double example_weight) const final { 97f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower if (label * wx < 1) { 98f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower return -label * example_weight; 99f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower } 100f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower return 0; 101f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower } 102f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower 103f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower // The smoothness constant is 0 since the derivative of the loss is not 104f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower // Lipschitz 105f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower double SmoothnessConstant() const final { return 0; } 106f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower 1079ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively 1089ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower // as expected by hinge loss. 1099ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower Status ConvertLabel(float* const example_label) const final { 1109ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower if (*example_label == 0.0) { 1119ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower *example_label = -1; 1129ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower return Status::OK(); 1139ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 1149ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower if (*example_label == 1.0) { 1159ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower return Status::OK(); 1169ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 1179ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower return errors::InvalidArgument( 1189ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower "Only labels of 0.0 or 1.0 are supported right now. " 1199ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower "Found example with label: ", 1209ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower *example_label); 1219ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 122cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower}; 1239ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 124cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower} // namespace tensorflow 125cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower 126eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#endif // TENSORFLOW_KERNELS_HINGE_LOSS_H_ 127