/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "stochastic_linear_ranker.h"

#include <stdlib.h>

#include <algorithm>
#include <string>
#include <unordered_map>

namespace learning_stochastic_linear {

246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid StochasticLinearRanker<Key, Hash>::UpdateSubGradient(
266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &positive,
276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &negative,
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double learning_rate,
296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double positive_score,
306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double negative_score,
316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const int32 gradient_l0_norm) {
326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  SparseWeightVector<Key, Hash> gradient;
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double final_learning_rate;
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  gradient.AdditiveWeightUpdate(1.0, positive, 0.0);
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  gradient.AdditiveWeightUpdate(-1.0, negative, 0.0);
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (update_type_ == FULL_CS || update_type_ == REG_CS) {
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double loss = std::max(0.0, (1 - positive_score + negative_score));
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double gradient_norm = gradient.L2Norm();
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double kMinGradientNorm = 1e-8;
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double kMaxGradientNorm = 1e8;
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (gradient_norm < kMinGradientNorm || gradient_norm > kMaxGradientNorm)
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return;
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (update_type_ == FULL_CS)
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      final_learning_rate =
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          std::min(lambda_, loss / (gradient_norm * gradient_norm));
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    else
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      final_learning_rate =
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          loss / (gradient_norm * gradient_norm + 1 / (2 * lambda_));
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } else {
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    gradient.AdditiveWeightUpdate(-lambda_, weight_, 0.0);
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    final_learning_rate = learning_rate;
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (gradient_l0_norm > 0) {
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    gradient.ReprojectL0(gradient_l0_norm);
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (gradient.IsValid())
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    weight_.AdditiveWeightUpdate(final_learning_rate, gradient, 0.0);
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint StochasticLinearRanker<Key, Hash>::UpdateClassifier(
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &positive,
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &negative) {
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Create a backup of the weight vector in case the iteration results in
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // unbounded weights.
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  SparseWeightVector<Key, Hash> weight_backup;
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  weight_backup.CopyFrom(weight_);
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  const double positive_score = ScoreSample(positive);
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  const double negative_score = ScoreSample(negative);
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if ((positive_score - negative_score) < 1) {
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ++mini_batch_counter_;
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if ((mini_batch_counter_ % mini_batch_size_ == 0) ||
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        (iteration_num_ == 0)) {
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iteration_num_;
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      mini_batch_counter_ = 0;
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    learning_rate_controller_.IncrementSample();
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    double learning_rate = learning_rate_controller_.GetLearningRate();
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (rank_loss_type_ == PAIRWISE) {
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      UpdateSubGradient(positive, negative, learning_rate,
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                        positive_score, negative_score,
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                        gradient_l0_norm_);
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    } else if (rank_loss_type_ == RECIPROCAL_RANK) {
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      const double current_negative_score = ScoreSample(current_negative_);
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((negative_score > current_negative_score) ||
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          ((rand()/RAND_MAX) < acceptence_probability_)) {
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        UpdateSubGradient(positive, negative, learning_rate,
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          positive_score, negative_score,
926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          gradient_l0_norm_);
936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        current_negative_.Clear();
946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        current_negative_.LoadWeightVector(negative);
956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      } else {
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        UpdateSubGradient(positive, current_negative_, learning_rate,
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          positive_score, negative_score,
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          gradient_l0_norm_);
996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      }
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    } else {
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ALOGE("Unknown rank loss type: %d", rank_loss_type_);
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int return_code;
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if ((mini_batch_counter_ == 0) && (update_type_ == SL)) {
1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return_code = 1;
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      switch (regularization_type_) {
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        case L1:
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          weight_.ReprojectL1(norm_constraint_);
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          break;
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        case L2:
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          weight_.ReprojectL2(norm_constraint_);
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          break;
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        case L0:
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          weight_.ReprojectL0(norm_constraint_);
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          break;
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        default:
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          ALOGE("Unsupported optimization type specified");
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          return_code = -1;
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      }
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    } else if (update_type_ == SL) {
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return_code = 2;
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    } else {
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return_code = 1;
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (!weight_.IsValid())
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      weight_.CopyFrom(weight_backup);
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return return_code;
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0;
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
// Explicit template instantiations for the key/hash-map combinations used by
// clients, so the template definitions can live in this .cpp file.
template class StochasticLinearRanker<std::string, std::unordered_map<std::string, double> >;
template class StochasticLinearRanker<int, std::unordered_map<int, double> >;
template class StochasticLinearRanker<uint64, std::unordered_map<uint64, double> >;
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
}  // namespace learning_stochastic_linear
140