// stochastic_linear_ranker.cpp revision a08525ea290ff4edc766eda1ec80388be866a79e
16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/* 26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project 36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License"); 56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License. 66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at 76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * http://www.apache.org/licenses/LICENSE-2.0 96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software 116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS, 126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and 146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License. 
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <algorithm> 186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <stdlib.h> 196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "stochastic_linear_ranker.h" 216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid StochasticLinearRanker<Key, Hash>::UpdateSubGradient( 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &positive, 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative, 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double learning_rate, 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double positive_score, 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double negative_score, 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const int32 gradient_l0_norm) { 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> gradient; 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double final_learning_rate; 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient.AdditiveWeightUpdate(1.0, positive, 0.0); 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient.AdditiveWeightUpdate(-1.0, negative, 0.0); 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (update_type_ == FULL_CS || update_type_ == REG_CS) { 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double loss = std::max(0.0, (1 - positive_score + negative_score)); 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double gradient_norm = gradient.L2Norm(); 
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double kMinGradientNorm = 1e-8; 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double kMaxGradientNorm = 1e8; 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (gradient_norm < kMinGradientNorm || gradient_norm > kMaxGradientNorm) 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (update_type_ == FULL_CS) 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua final_learning_rate = 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua std::min(lambda_, loss / (gradient_norm * gradient_norm)); 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua else 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua final_learning_rate = 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua loss / (gradient_norm * gradient_norm + 1 / (2 * lambda_)); 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient.AdditiveWeightUpdate(-lambda_, weight_, 0.0); 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua final_learning_rate = learning_rate; 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (gradient_l0_norm > 0) { 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient.ReprojectL0(gradient_l0_norm); 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (gradient.IsValid()) 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.AdditiveWeightUpdate(final_learning_rate, gradient, 0.0); 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint StochasticLinearRanker<Key, Hash>::UpdateClassifier( 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> 
&positive, 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &negative) { 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Create a backup of the weight vector in case the iteration results in 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // unbounded weights. 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector<Key, Hash> weight_backup; 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_backup.CopyFrom(weight_); 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double positive_score = ScoreSample(positive); 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double negative_score = ScoreSample(negative); 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((positive_score - negative_score) < 1) { 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++mini_batch_counter_; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((mini_batch_counter_ % mini_batch_size_ == 0) || 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (iteration_num_ == 0)) { 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iteration_num_; 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 0; 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua learning_rate_controller_.IncrementSample(); 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double learning_rate = learning_rate_controller_.GetLearningRate(); 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (rank_loss_type_ == PAIRWISE) { 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateSubGradient(positive, negative, learning_rate, 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua positive_score, negative_score, 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_); 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (rank_loss_type_ == RECIPROCAL_RANK) { 
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double current_negative_score = ScoreSample(current_negative_); 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((negative_score > current_negative_score) || 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ((rand()/RAND_MAX) < acceptence_probability_)) { 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateSubGradient(positive, negative, learning_rate, 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua positive_score, negative_score, 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_); 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua current_negative_.Clear(); 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua current_negative_.LoadWeightVector(negative); 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua UpdateSubGradient(positive, current_negative_, learning_rate, 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua positive_score, negative_score, 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua gradient_l0_norm_); 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("Unknown rank loss type: %d", rank_loss_type_); 1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int return_code; 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((mini_batch_counter_ == 0) && (update_type_ == SL)) { 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return_code = 1; 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua switch (regularization_type_) { 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case L1: 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.ReprojectL1(norm_constraint_); 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua break; 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case L2: 
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.ReprojectL2(norm_constraint_); 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua break; 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua case L0: 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.ReprojectL0(norm_constraint_); 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua break; 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua default: 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("Unsupported optimization type specified"); 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return_code = -1; 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (update_type_ == SL) { 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return_code = 2; 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return_code = 1; 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (!weight_.IsValid()) 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua weight_.CopyFrom(weight_backup); 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return return_code; 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0; 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 135a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate class StochasticLinearRanker<std::string, std::unordered_map<std::string, double> >; 136a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate class StochasticLinearRanker<int, std::unordered_map<int, double> >; 137a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate class StochasticLinearRanker<uint64, std::unordered_map<uint64, double> >; 1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 140