stochastic_linear_ranker.h revision a08525ea290ff4edc766eda1ec80388be866a79e
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Stochastic Linear Ranking algorithms.
// This class will implement a set of incremental algorithms for ranking
// tasks. They support both L1 and L2 regularizations.
206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#ifndef LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#define LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 25a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <sys/types.h> 26a08525ea290ff4edc766eda1ec80388be866a79eDan Albert 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <cmath> 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <string> 29a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <unordered_map> 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "cutils/log.h" 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "common_defs.h" 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "learning_rate_controller-inl.h" 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "sparse_weight_vector.h" 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// NOTE: This Stochastic Linear Ranker supports only the following update types: 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// SL: Stochastic Linear 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// CS: Constraint Satisfaction 41a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate<class Key = std::string, class Hash = std::unordered_map<std::string, double> > 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass StochasticLinearRanker { 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public: 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // initialize lambda_ and constraint to a meaningful default. 
Will give 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // equal weight to the error and regularizer. 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua StochasticLinearRanker() { 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_ = 0; 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = 1.0; 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetLambda(lambda_); 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = 1; 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetMiniBatchSize(mini_batch_size_); 52b019e89cbea221598c482b05ab68b7660b41aa23saberian adaptation_mode_ = INV_LINEAR; 53b019e89cbea221598c482b05ab68b7660b41aa23saberian learning_rate_controller_.SetAdaptationMode(adaptation_mode_); 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua update_type_ = SL; 55b019e89cbea221598c482b05ab68b7660b41aa23saberian regularization_type_ = L2; 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_type_ = LINEAR; 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_param_ = 1.0; 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_gain_ = 1.0; 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_bias_ = 0.0; 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rank_loss_type_ = PAIRWISE; 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua acceptence_probability_ = 0.1; 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 0; 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_ = -1; 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm_constraint_ = 1.0; 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ~StochasticLinearRanker() {} 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Getters and setters 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetIterationNumber() const { 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 
return iteration_num_; 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetNormContraint() const { 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm_constraint_; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RegularizationType GetRegularizationType() const { 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return regularization_type_; 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetLambda() const { 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return lambda_; 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 GetMiniBatchSize() const { 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return mini_batch_size_; 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int32 GetGradientL0Norm() const { 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return gradient_l0_norm_; 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateType GetUpdateType() const { 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return update_type_; 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode GetAdaptationMode() const { 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return adaptation_mode_; 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 93b019e89cbea221598c482b05ab68b7660b41aa23saberian KernelType GetKernelType() const { 94b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_type_; 95b019e89cbea221598c482b05ab68b7660b41aa23saberian } 96b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function returns the basic kernel parameter. 
In case of 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial kernel, it implies the degree of the polynomial. In case of 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // RBF kernel, it implies the sigma parameter. In case of linear kernel, 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // it is not used. 100b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelParam() const { 101b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_param_; 102b019e89cbea221598c482b05ab68b7660b41aa23saberian } 103b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelGain() const { 104b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_gain_;; 105b019e89cbea221598c482b05ab68b7660b41aa23saberian } 106b019e89cbea221598c482b05ab68b7660b41aa23saberian double GetKernelBias() const { 107b019e89cbea221598c482b05ab68b7660b41aa23saberian return kernel_bias_; 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RankLossType GetRankLossType() const { 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return rank_loss_type_; 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetAcceptanceProbability() const { 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return acceptence_probability_; 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetIterationNumber(uint64 num) { 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_=num; 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetNormConstraint(const double norm) { 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm_constraint_ = norm; 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetRegularizationType(const RegularizationType r) { 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Hua regularization_type_ = r; 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetLambda(double l) { 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = l; 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetLambda(l); 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetMiniBatchSize(const uint64 msize) { 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = msize; 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetMiniBatchSize(msize); 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetAdaptationMode(AdaptationMode m) { 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua adaptation_mode_ = m; 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.SetAdaptationMode(m); 1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 136b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelType(KernelType k ) { 137b019e89cbea221598c482b05ab68b7660b41aa23saberian kernel_type_ = k; 138b019e89cbea221598c482b05ab68b7660b41aa23saberian } 139b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the basic kernel parameter. In case of 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial kernel, it implies the degree of the polynomial. In case of 1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // RBF kernel, it implies the sigma parameter. In case of linear kernel, 1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // it is not used. 143b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelParam(double param) { 1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_param_ = param; 145b019e89cbea221598c482b05ab68b7660b41aa23saberian } 146b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the kernel gain. 
NOTE: in most use cases, gain should 147b019e89cbea221598c482b05ab68b7660b41aa23saberian // be set to 1.0. 148b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelGain(double gain) { 1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_gain_ = gain; 150b019e89cbea221598c482b05ab68b7660b41aa23saberian } 151b019e89cbea221598c482b05ab68b7660b41aa23saberian // This function sets the kernel bias. NOTE: in most use cases, bias should 152b019e89cbea221598c482b05ab68b7660b41aa23saberian // be set to 0.0. 153b019e89cbea221598c482b05ab68b7660b41aa23saberian void SetKernelBias(double bias) { 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua kernel_bias_ = bias; 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetUpdateType(UpdateType u) { 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua update_type_ = u; 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetRankLossType(RankLossType r) { 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rank_loss_type_ = r; 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetAcceptanceProbability(double p) { 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua acceptence_probability_ = p; 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetGradientL0Norm(const int32 gradient_l0_norm) { 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_ = gradient_l0_norm; 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Load an existing model 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void LoadWeights(const SparseWeightVector<Key, Hash> &model) { 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.LoadWeightVector(model); 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Hua // Save current model 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SaveWeights(SparseWeightVector<Key, Hash> *model) { 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua model->LoadWeightVector(weight_); 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Scoring 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double ScoreSample(const SparseWeightVector<Key, Hash> &sample) { 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double dot = weight_.DotProduct(sample); 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double w_square; 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double s_square; 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua switch (kernel_type_) { 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case LINEAR: 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return dot; 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case POLY: 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return pow(kernel_gain_ * dot + kernel_bias_, kernel_param_); 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case RBF: 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_square = weight_.L2Norm(); 1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua s_square = sample.L2Norm(); 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return exp(-1 * kernel_param_ * (w_square + s_square - 2 * dot)); 1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua default: 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("unsupported kernel: %d", kernel_type_); 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return -1; 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Learning Functions 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Return values: 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 1 :full update went through 1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 
// 2 :partial update went through (for SL only) 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 0 :no update necessary. 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // -1:error. 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int UpdateClassifier(const SparseWeightVector<Key, Hash> &positive, 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative); 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private: 2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> weight_; 2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double norm_constraint_; 2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double lambda_; 2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RegularizationType regularization_type_; 2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode adaptation_mode_; 2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateType update_type_; 2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua RankLossType rank_loss_type_; 2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua KernelType kernel_type_; 2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Kernel gain and bias are typically multiplicative and additive factors to 2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // the dot product while calculating the kernel function. Kernel param is 2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // kernel-specific. In case of polynomial kernel, it is the degree of the 2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // polynomial. 
2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_param_; 2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_gain_; 2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double kernel_bias_; 2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double acceptence_probability_; 2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> current_negative_; 2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua LearningRateController learning_rate_controller_; 2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 iteration_num_; 2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // We average out gradient updates for mini_batch_size_ samples 2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // before performing an iteration of the algorithm. 2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_counter_; 2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_size_; 2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Specifies the number of non-zero entries allowed in a gradient. 2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Default is -1 which means we take the gradient as given by data without 2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // adding any new constraints. 
positive number is treated as an L0 constraint 2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int32 gradient_l0_norm_; 2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Sub-Gradient Updates 2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Pure Sub-Gradient update without any reprojection 2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Note that a form of L2 regularization is built into this 2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void UpdateSubGradient(const SparseWeightVector<Key, Hash> &positive, 2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative, 2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double learning_rate, 2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double positive_score, 2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double negative_score, 2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const int32 gradient_l0_norm); 2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}; 2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#endif // LEARNING_STOCHASTIC_LINEAR_STOCHASTIC_LINEAR_RANKER_H_ 245