// stochastic_linear_ranker.h  (revision b019e89cbea221598c482b05ab68b7660b41aa23)
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Stochastic Linear Ranking algorithms.
// This class implements a set of incremental algorithms for ranking tasks.
// They support both L1 and L2 regularizations.


#ifndef LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_
#define LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_

#include <sys/types.h>

#include <cmath>
#include <hash_map>
#include <string>
#include <unordered_map>

#include "cutils/log.h"
#include "common_defs.h"
#include "learning_rate_controller-inl.h"
#include "sparse_weight_vector.h"
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear {
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// NOTE: This Stochastic Linear Ranker supports only the following update types:
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// SL: Stochastic Linear
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// CS: Constraint Satisfaction
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key = std::string, class Hash = std::hash_map<std::string, double> >
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass StochasticLinearRanker {
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public:
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // initialize lambda_ and constraint to a meaningful default. Will give
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // equal weight to the error and regularizer.
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  StochasticLinearRanker() {
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iteration_num_ = 0;
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    lambda_ = 1.0;
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.SetLambda(lambda_);
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    mini_batch_size_ = 1;
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.SetMiniBatchSize(mini_batch_size_);
51b019e89cbea221598c482b05ab68b7660b41aa23saberian    adaptation_mode_ = INV_LINEAR;
52b019e89cbea221598c482b05ab68b7660b41aa23saberian    learning_rate_controller_.SetAdaptationMode(adaptation_mode_);
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    update_type_ = SL;
54b019e89cbea221598c482b05ab68b7660b41aa23saberian    regularization_type_ = L2;
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_type_ = LINEAR;
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_param_ = 1.0;
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_gain_ = 1.0;
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_bias_ = 0.0;
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    rank_loss_type_ = PAIRWISE;
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    acceptence_probability_ = 0.1;
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    mini_batch_counter_ = 0;
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    gradient_l0_norm_ = -1;
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm_constraint_ = 1.0;
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  ~StochasticLinearRanker() {}
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Getters and setters
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetIterationNumber() const {
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return iteration_num_;
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetNormContraint() const {
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return norm_constraint_;
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  RegularizationType GetRegularizationType() const {
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return regularization_type_;
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetLambda() const {
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return lambda_;
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 GetMiniBatchSize() const {
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return mini_batch_size_;
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int32 GetGradientL0Norm() const {
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return gradient_l0_norm_;
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  UpdateType GetUpdateType() const {
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return update_type_;
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  AdaptationMode GetAdaptationMode() const {
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return adaptation_mode_;
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
92b019e89cbea221598c482b05ab68b7660b41aa23saberian  KernelType GetKernelType() const {
93b019e89cbea221598c482b05ab68b7660b41aa23saberian    return kernel_type_;
94b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
95b019e89cbea221598c482b05ab68b7660b41aa23saberian  // This function returns the basic kernel parameter. In case of
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // polynomial kernel, it implies the degree of the polynomial.  In case of
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // RBF kernel, it implies the sigma parameter. In case of linear kernel,
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // it is not used.
99b019e89cbea221598c482b05ab68b7660b41aa23saberian  double GetKernelParam() const {
100b019e89cbea221598c482b05ab68b7660b41aa23saberian    return kernel_param_;
101b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
102b019e89cbea221598c482b05ab68b7660b41aa23saberian  double GetKernelGain() const {
103b019e89cbea221598c482b05ab68b7660b41aa23saberian    return kernel_gain_;;
104b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
105b019e89cbea221598c482b05ab68b7660b41aa23saberian  double GetKernelBias() const {
106b019e89cbea221598c482b05ab68b7660b41aa23saberian    return kernel_bias_;
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  RankLossType GetRankLossType() const {
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return rank_loss_type_;
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetAcceptanceProbability() const {
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return acceptence_probability_;
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetIterationNumber(uint64 num) {
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iteration_num_=num;
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetNormConstraint(const double norm) {
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm_constraint_ = norm;
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetRegularizationType(const RegularizationType r) {
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    regularization_type_ = r;
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetLambda(double l) {
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    lambda_ = l;
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.SetLambda(l);
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetMiniBatchSize(const uint64 msize) {
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    mini_batch_size_ = msize;
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.SetMiniBatchSize(msize);
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetAdaptationMode(AdaptationMode m) {
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    adaptation_mode_ = m;
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.SetAdaptationMode(m);
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
135b019e89cbea221598c482b05ab68b7660b41aa23saberian  void SetKernelType(KernelType k ) {
136b019e89cbea221598c482b05ab68b7660b41aa23saberian    kernel_type_ = k;
137b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
138b019e89cbea221598c482b05ab68b7660b41aa23saberian  // This function sets the basic kernel parameter. In case of
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // polynomial kernel, it implies the degree of the polynomial. In case of
1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // RBF kernel, it implies the sigma parameter. In case of linear kernel,
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // it is not used.
142b019e89cbea221598c482b05ab68b7660b41aa23saberian  void SetKernelParam(double param) {
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_param_ = param;
144b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
145b019e89cbea221598c482b05ab68b7660b41aa23saberian  // This function sets the kernel gain. NOTE: in most use cases, gain should
146b019e89cbea221598c482b05ab68b7660b41aa23saberian  // be set to 1.0.
147b019e89cbea221598c482b05ab68b7660b41aa23saberian  void SetKernelGain(double gain) {
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_gain_ = gain;
149b019e89cbea221598c482b05ab68b7660b41aa23saberian  }
150b019e89cbea221598c482b05ab68b7660b41aa23saberian  // This function sets the kernel bias. NOTE: in most use cases, bias should
151b019e89cbea221598c482b05ab68b7660b41aa23saberian  // be set to 0.0.
152b019e89cbea221598c482b05ab68b7660b41aa23saberian  void SetKernelBias(double bias) {
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    kernel_bias_ = bias;
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetUpdateType(UpdateType u) {
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    update_type_ = u;
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetRankLossType(RankLossType r) {
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    rank_loss_type_ = r;
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetAcceptanceProbability(double p) {
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    acceptence_probability_ = p;
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetGradientL0Norm(const int32 gradient_l0_norm) {
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    gradient_l0_norm_ = gradient_l0_norm;
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Load an existing model
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void LoadWeights(const SparseWeightVector<Key, Hash> &model) {
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    weight_.LoadWeightVector(model);
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Save current model
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SaveWeights(SparseWeightVector<Key, Hash> *model) {
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    model->LoadWeightVector(weight_);
1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Scoring
1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double ScoreSample(const SparseWeightVector<Key, Hash> &sample) {
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double dot = weight_.DotProduct(sample);
1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    double w_square;
1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    double s_square;
1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    switch (kernel_type_) {
1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      case LINEAR:
1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return dot;
1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      case POLY:
1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return pow(kernel_gain_ * dot + kernel_bias_, kernel_param_);
1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      case RBF:
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        w_square = weight_.L2Norm();
1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        s_square = sample.L2Norm();
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return exp(-1 * kernel_param_ * (w_square + s_square - 2 * dot));
1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      default:
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ALOGE("unsupported kernel: %d", kernel_type_);
1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return -1;
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Learning Functions
1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Return values:
1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 1 :full update went through
1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 2 :partial update went through (for SL only)
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 0 :no update necessary.
1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // -1:error.
2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int UpdateClassifier(const SparseWeightVector<Key, Hash> &positive,
2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                       const SparseWeightVector<Key, Hash> &negative);
2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private:
2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  SparseWeightVector<Key, Hash> weight_;
2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double norm_constraint_;
2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double lambda_;
2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  RegularizationType regularization_type_;
2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  AdaptationMode adaptation_mode_;
2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  UpdateType update_type_;
2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  RankLossType rank_loss_type_;
2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  KernelType kernel_type_;
2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Kernel gain and bias are typically multiplicative and additive factors to
2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // the dot product while calculating the kernel function. Kernel param is
2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // kernel-specific. In case of polynomial kernel, it is the degree of the
2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // polynomial.
2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double kernel_param_;
2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double kernel_gain_;
2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double kernel_bias_;
2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double acceptence_probability_;
2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  SparseWeightVector<Key, Hash> current_negative_;
2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  LearningRateController learning_rate_controller_;
2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 iteration_num_;
2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // We average out gradient updates for mini_batch_size_ samples
2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // before performing an iteration of the algorithm.
2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 mini_batch_counter_;
2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 mini_batch_size_;
2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Specifies the number of non-zero entries allowed in a gradient.
2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Default is -1 which means we take the gradient as given by data without
2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // adding any new constraints. positive number is treated as an L0 constraint
2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int32 gradient_l0_norm_;
2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Sub-Gradient Updates
2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Pure Sub-Gradient update without any reprojection
2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Note that a form of L2 regularization is built into this
2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void UpdateSubGradient(const SparseWeightVector<Key, Hash> &positive,
2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                         const SparseWeightVector<Key, Hash> &negative,
2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                         const double learning_rate,
2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                         const double positive_score,
2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                         const double negative_score,
2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                         const int32 gradient_l0_norm);
2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua};
2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}  // namespace learning_stochastic_linear
#endif  // LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_