// stochastic_linear_ranker.h (revision b019e89cbea221598c482b05ab68b7660b41aa23)
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Stochastic Linear Ranking algorithms.
// This class implements a set of incremental algorithms for ranking tasks.
// They support both L1 and L2 regularization.
206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#ifndef LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#define LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <cmath> 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <hash_map> 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <string> 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <sys/types.h> 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "cutils/log.h" 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "common_defs.h" 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "learning_rate_controller-inl.h" 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "sparse_weight_vector.h" 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// NOTE: This Stochastic Linear Ranker supports only the following update types: 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// SL: Stochastic Linear 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// CS: Constraint Satisfaction 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key = std::string, class Hash = std::hash_map<std::string, double> > 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass StochasticLinearRanker { 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public: 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // initialize lambda_ and constraint to a meaningful default. Will give 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // equal weight to the error and regularizer. 
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua StochasticLinearRanker() { 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_ = 0; 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = 1.0; 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetLambda(lambda_); 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = 1; 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetMiniBatchSize(mini_batch_size_); 51b019e89cbea221598c482b05ab68b7660b41aa23saberian adaptation_mode_ = INV_LINEAR; 52b019e89cbea221598c482b05ab68b7660b41aa23saberian learning_rate_controller_.SetAdaptationMode(adaptation_mode_); 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua update_type_ = SL; 54b019e89cbea221598c482b05ab68b7660b41aa23saberian regularization_type_ = L2; 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_type_ = LINEAR; 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_param_ = 1.0; 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_gain_ = 1.0; 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_bias_ = 0.0; 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rank_loss_type_ = PAIRWISE; 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua acceptence_probability_ = 0.1; 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 0; 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_ = -1; 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm_constraint_ = 1.0; 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ~StochasticLinearRanker() {} 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Getters and setters 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetIterationNumber() const { 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return iteration_num_; 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetNormContraint() const { 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm_constraint_; 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RegularizationType GetRegularizationType() const { 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return regularization_type_; 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetLambda() const { 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return lambda_; 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 GetMiniBatchSize() const { 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return mini_batch_size_; 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int32 GetGradientL0Norm() const { 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return gradient_l0_norm_; 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateType GetUpdateType() const { 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return update_type_; 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode GetAdaptationMode() const { 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return adaptation_mode_; 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 92b019e89cbea221598c482b05ab68b7660b41aa23saberian KernelType GetKernelType() const { 93b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_type_; 94b019e89cbea221598c482b05ab68b7660b41aa23saberian } 95b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function returns the basic kernel parameter. In case of 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial kernel, it implies the degree of the polynomial. 
In case of 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // RBF kernel, it implies the sigma parameter. In case of linear kernel, 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // it is not used. 99b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelParam() const { 100b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_param_; 101b019e89cbea221598c482b05ab68b7660b41aa23saberian } 102b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelGain() const { 103b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_gain_;; 104b019e89cbea221598c482b05ab68b7660b41aa23saberian } 105b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelBias() const { 106b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_bias_; 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RankLossType GetRankLossType() const { 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return rank_loss_type_; 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetAcceptanceProbability() const { 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return acceptence_probability_; 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetIterationNumber(uint64 num) { 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_=num; 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetNormConstraint(const double norm) { 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm_constraint_ = norm; 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetRegularizationType(const RegularizationType r) { 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua regularization_type_ = r; 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetLambda(double l) { 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = l; 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetLambda(l); 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetMiniBatchSize(const uint64 msize) { 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = msize; 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetMiniBatchSize(msize); 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetAdaptationMode(AdaptationMode m) { 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua adaptation_mode_ = m; 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetAdaptationMode(m); 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 135b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelType(KernelType k ) { 136b019e89cbea221598c482b05ab68b7660b41aa23saberian kernel_type_ = k; 137b019e89cbea221598c482b05ab68b7660b41aa23saberian } 138b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the basic kernel parameter. In case of 1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial kernel, it implies the degree of the polynomial. In case of 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // RBF kernel, it implies the sigma parameter. In case of linear kernel, 1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // it is not used. 142b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelParam(double param) { 1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_param_ = param; 144b019e89cbea221598c482b05ab68b7660b41aa23saberian } 145b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the kernel gain. NOTE: in most use cases, gain should 146b019e89cbea221598c482b05ab68b7660b41aa23saberian // be set to 1.0. 
147b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelGain(double gain) { 1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_gain_ = gain; 149b019e89cbea221598c482b05ab68b7660b41aa23saberian } 150b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the kernel bias. NOTE: in most use cases, bias should 151b019e89cbea221598c482b05ab68b7660b41aa23saberian // be set to 0.0. 152b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelBias(double bias) { 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_bias_ = bias; 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetUpdateType(UpdateType u) { 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua update_type_ = u; 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetRankLossType(RankLossType r) { 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rank_loss_type_ = r; 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetAcceptanceProbability(double p) { 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua acceptence_probability_ = p; 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetGradientL0Norm(const int32 gradient_l0_norm) { 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_ = gradient_l0_norm; 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Load an existing model 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void LoadWeights(const SparseWeightVector<Key, Hash> &model) { 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.LoadWeightVector(model); 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Save current model 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void 
SaveWeights(SparseWeightVector<Key, Hash> *model) { 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua model->LoadWeightVector(weight_); 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Scoring 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double ScoreSample(const SparseWeightVector<Key, Hash> &sample) { 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double dot = weight_.DotProduct(sample); 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double w_square; 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double s_square; 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua switch (kernel_type_) { 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case LINEAR: 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return dot; 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case POLY: 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return pow(kernel_gain_ * dot + kernel_bias_, kernel_param_); 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case RBF: 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_square = weight_.L2Norm(); 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua s_square = sample.L2Norm(); 1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return exp(-1 * kernel_param_ * (w_square + s_square - 2 * dot)); 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua default: 1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("unsupported kernel: %d", kernel_type_); 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return -1; 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Learning Functions 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Return values: 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 1 :full update went through 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 2 :partial update went through (for SL only) 
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 0 :no update necessary. 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // -1:error. 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int UpdateClassifier(const SparseWeightVector<Key, Hash> &positive, 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative); 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private: 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> weight_; 2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double norm_constraint_; 2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double lambda_; 2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RegularizationType regularization_type_; 2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode adaptation_mode_; 2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateType update_type_; 2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RankLossType rank_loss_type_; 2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua KernelType kernel_type_; 2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Kernel gain and bias are typically multiplicative and additive factors to 2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // the dot product while calculating the kernel function. Kernel param is 2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // kernel-specific. In case of polynomial kernel, it is the degree of the 2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial. 
2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_param_; 2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_gain_; 2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_bias_; 2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double acceptence_probability_; 2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> current_negative_; 2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua LearningRateController learning_rate_controller_; 2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 iteration_num_; 2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // We average out gradient updates for mini_batch_size_ samples 2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // before performing an iteration of the algorithm. 2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_counter_; 2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_size_; 2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Specifies the number of non-zero entries allowed in a gradient. 2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Default is -1 which means we take the gradient as given by data without 2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // adding any new constraints. 
positive number is treated as an L0 constraint 2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int32 gradient_l0_norm_; 2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Sub-Gradient Updates 2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Pure Sub-Gradient update without any reprojection 2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Note that a form of L2 regularization is built into this 2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void UpdateSubGradient(const SparseWeightVector<Key, Hash> &positive, 2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative, 2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double learning_rate, 2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double positive_score, 2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double negative_score, 2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const int32 gradient_l0_norm); 2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}; 2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#endif // LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 244