/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Implements learning rate adaptations common to most stochastic algorithms.
186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#ifndef LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#define LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <cmath> 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "common_defs.h" 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass LearningRateController { 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public: 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua LearningRateController() { 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_ = 1; 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = 1.0; 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = 1; 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 1; 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua sample_num_ = 1; 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mode_ = INV_LINEAR; 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua is_first_sample_ = true; 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ~LearningRateController() {} 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Getters and Setters for learning rate parameter lambda_ 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetLambda() const { 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return lambda_; 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetLambda(double lambda) { 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua lambda_ = lambda; 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Hua } 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Operations on current iteration number 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetIterationNumber(uint64 num) { 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_ = num; 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void IncrementIteration() { 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iteration_num_; 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 GetIterationNumber() const { 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return iteration_num_; 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Mini batch operations 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 GetMiniBatchSize() const { 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return mini_batch_size_; 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetMiniBatchSize(uint64 size) { 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //CHECK_GT(size, 0); 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = size; 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void IncrementSample() { 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // If this is the first sample we've already counted it to prevent NaNs 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // in the learning rate computation 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_first_sample_) { 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua is_first_sample_ = false; 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++sample_num_; 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (1 == mini_batch_size_) { 
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua IncrementIteration(); 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 0; 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++mini_batch_counter_; 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((mini_batch_counter_ % mini_batch_size_ == 0)) { 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua IncrementIteration(); 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = 0; 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 GetMiniBatchCounter() const { 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return mini_batch_counter_; 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Getters and setters for adaptation mode 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode GetAdaptationMode() const { 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return mode_; 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetAdaptationMode(AdaptationMode m) { 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mode_ = m; 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetLearningRate() const { 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (mode_ == CONST) { 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return (1.0 / (lambda_ * mini_batch_size_)); 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (mode_ == INV_LINEAR) { 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return (1.0 / (lambda_ * iteration_num_ * mini_batch_size_)); 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (mode_ == INV_QUADRATIC) { 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 
(1.0 / (lambda_ * 1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ * 1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (static_cast<double>(iteration_num_) * iteration_num_))); 1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (mode_ == INV_SQRT) { 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return (1.0 / (lambda_ * 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ * 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua sqrt((double)iteration_num_))); 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0; 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void CopyFrom(const LearningRateController &other) { 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iteration_num_ = other.iteration_num_; 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua sample_num_ = other.sample_num_; 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_size_ = other.mini_batch_size_; 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mini_batch_counter_ = other.mini_batch_counter_; 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mode_ = other.mode_; 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua is_first_sample_ = other.is_first_sample_; 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private: 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 iteration_num_; 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 sample_num_; 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_size_; 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 mini_batch_counter_; 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double lambda_; 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua AdaptationMode mode_; 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_first_sample_; 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}; 
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#endif // LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 128