/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <stdlib.h>

#include "stochastic_linear_ranker.h"

namespace learning_stochastic_linear {
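
// Applies a single update step for the pairwise hinge loss on one
// (positive, negative) example pair. The update direction is
// (positive - negative). For FULL_CS and REG_CS the step size is computed in
// closed form from the hinge loss and the squared norm of that direction
// (skipping the step if the norm is degenerate); otherwise the L2
// regularization term -lambda_ * weight_ is folded in and the supplied
// learning_rate is used directly.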
template<class Key, class Hash>
void StochasticLinearRanker<Key, Hash>::UpdateSubGradient(
    const SparseWeightVector<Key, Hash> &positive,
    const SparseWeightVector<Key, Hash> &negative,
    const double learning_rate,
    const double positive_score,
    const double negative_score,
    const int32 gradient_l0_norm) {
  SparseWeightVector<Key, Hash> gradient;
  double final_learning_rate;
  gradient.AdditiveWeightUpdate(1.0, positive, 0.0);
  gradient.AdditiveWeightUpdate(-1.0, negative, 0.0);
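  // Closed-form (passive-aggressive style) step size: the hinge loss over the
  // squared gradient norm, clipped at lambda_ for FULL_CS or smoothed by
  // 1 / (2 * lambda_) for REG_CS.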
  if (update_type_ == FULL_CS || update_type_ == REG_CS) {
    const double loss = std::max(0.0, (1 - positive_score + negative_score));
    const double gradient_norm = gradient.L2Norm();
    const double kMinGradientNorm = 1e-8;
    const double kMaxGradientNorm = 1e8;
    if (gradient_norm < kMinGradientNorm || gradient_norm > kMaxGradientNorm)
      return;
    if (update_type_ == FULL_CS)
      final_learning_rate =
          std::min(lambda_, loss / (gradient_norm * gradient_norm));
    else
      final_learning_rate =
          loss / (gradient_norm * gradient_norm + 1 / (2 * lambda_));
  } else {
    gradient.AdditiveWeightUpdate(-lambda_, weight_, 0.0);
    final_learning_rate = learning_rate;
  }
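  // Optionally restrict the update direction to at most gradient_l0_norm
  // non-zero entries before applying it.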
  if (gradient_l0_norm > 0) {
    gradient.ReprojectL0(gradient_l0_norm);
  }

  if (gradient.IsValid())
    weight_.AdditiveWeightUpdate(final_learning_rate, gradient, 0.0);
}
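
// Applies one ranking update for a (positive, negative) pair. Returns 0 when
// the pair already satisfies the margin (no update); 1 when an update was
// applied (for SL updates this includes the end-of-mini-batch reprojection);
// 2 when an SL update was applied but reprojection is deferred until the
// mini-batch completes; and -1 when the configured regularization type is
// unsupported.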
template<class Key, class Hash>
int StochasticLinearRanker<Key, Hash>::UpdateClassifier(
    const SparseWeightVector<Key, Hash> &positive,
    const SparseWeightVector<Key, Hash> &negative) {
  // Create a backup of the weight vector in case the iteration results in
  // unbounded weights.
  SparseWeightVector<Key, Hash> weight_backup;
  weight_backup.CopyFrom(weight_);

  const double positive_score = ScoreSample(positive);
  const double negative_score = ScoreSample(negative);
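  // Update only when the pair violates the unit ranking margin.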
  if ((positive_score - negative_score) < 1) {
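    // Mini-batch bookkeeping: iteration_num_ advances once per completed
    // mini-batch (and on the very first update), at which point the counter
    // resets.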
    ++mini_batch_counter_;
    if ((mini_batch_counter_ % mini_batch_size_ == 0) ||
        (iteration_num_ == 0)) {
      ++iteration_num_;
      mini_batch_counter_ = 0;
    }
    learning_rate_controller_.IncrementSample();
    double learning_rate = learning_rate_controller_.GetLearningRate();

    if (rank_loss_type_ == PAIRWISE) {
      UpdateSubGradient(positive, negative, learning_rate,
                        positive_score, negative_score,
                        gradient_l0_norm_);
    } else if (rank_loss_type_ == RECIPROCAL_RANK) {
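      // For reciprocal-rank loss, keep the hardest negative seen so far in
      // current_negative_: update against the incoming negative when it
      // outscores the stored one (or, with probability acceptence_probability_,
      // regardless) and store it; otherwise update against the stored negative.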
      const double current_negative_score = ScoreSample(current_negative_);
      if ((negative_score > current_negative_score) ||
          ((static_cast<double>(rand()) / RAND_MAX) < acceptence_probability_)) {
        UpdateSubGradient(positive, negative, learning_rate,
                          positive_score, negative_score,
                          gradient_l0_norm_);
        current_negative_.Clear();
        current_negative_.LoadWeightVector(negative);
      } else {
        UpdateSubGradient(positive, current_negative_, learning_rate,
                          positive_score, negative_score,
                          gradient_l0_norm_);
      }
    } else {
      ALOGE("Unknown rank loss type: %d", rank_loss_type_);
    }
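
    // For SL updates, reproject the weights onto the configured norm
    // constraint once per mini-batch; until then, signal via the return code
    // that reprojection is still pending.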
    int return_code;
    if ((mini_batch_counter_ == 0) && (update_type_ == SL)) {
      return_code = 1;
      switch (regularization_type_) {
        case L1:
          weight_.ReprojectL1(norm_constraint_);
          break;
        case L2:
          weight_.ReprojectL2(norm_constraint_);
          break;
        case L0:
          weight_.ReprojectL0(norm_constraint_);
          break;
        default:
          ALOGE("Unsupported regularization type specified");
          return_code = -1;
      }
    } else if (update_type_ == SL) {
      return_code = 2;
    } else {
      return_code = 1;
    }

    if (!weight_.IsValid())
      weight_.CopyFrom(weight_backup);
    return return_code;
  }

  return 0;
}
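
// Explicit instantiations for the key and hash-map types used by the ranker.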
template class StochasticLinearRanker<std::string, std::unordered_map<std::string, double> >;
template class StochasticLinearRanker<int, std::unordered_map<int, double> >;
template class StochasticLinearRanker<uint64, std::unordered_map<uint64, double> >;

}  // namespace learning_stochastic_linear