1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//
// This file contains the implementation of the MulticlassPA class, a simple
// linear multi-class classifier trained with the multi-prototype variant of
// the Passive-Aggressive (PA) online learning algorithm.
21
#include "native/multiclass_pa.h"

#include <limits>

using std::vector;
using std::pair;
26
27namespace learningfw {
28
// Returns a pseudo-random float in [0, 1), drawn from rand().
//
// The upper bound is exclusive so callers can safely compute an index in
// [0, n) as static_cast<int>(RandFloat() * n). A plain rand() / RAND_MAX
// returns exactly 1.0 when rand() == RAND_MAX, and with a large RAND_MAX
// (e.g. 2^31 - 1) the float conversion also rounds the top few values up
// to 1.0f. Such draws are rejected and redrawn.
float RandFloat() {
  float value;
  do {
    value = static_cast<float>(rand()) / RAND_MAX;
  } while (value >= 1.0f);
  return value;
}
32
33MulticlassPA::MulticlassPA(int num_classes,
34                           int num_dimensions,
35                           float aggressiveness)
36    : num_classes_(num_classes),
37      num_dimensions_(num_dimensions),
38      aggressiveness_(aggressiveness) {
39  InitializeParameters();
40}
41
42MulticlassPA::~MulticlassPA() {
43}
44
45void MulticlassPA::InitializeParameters() {
46  parameters_.resize(num_classes_);
47  for (int i = 0; i < num_classes_; ++i) {
48    parameters_[i].resize(num_dimensions_);
49    for (int j = 0; j < num_dimensions_; ++j) {
50      parameters_[i][j] = 0.0;
51    }
52  }
53}
54
55int MulticlassPA::PickAClassExcept(int target) {
56  int picked;
57  do {
58    picked = static_cast<int>(RandFloat() * num_classes_);
59    //    picked = static_cast<int>(random_.RandFloat() * num_classes_);
60  } while (target == picked);
61  return picked;
62}
63
64int MulticlassPA::PickAnExample(int num_examples) {
65  return static_cast<int>(RandFloat() * num_examples);
66}
67
68float MulticlassPA::Score(const vector<float>& inputs,
69                          const vector<float>& parameters) const {
70  // CHECK_EQ(inputs.size(), parameters.size());
71  float result = 0.0;
72  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
73    result += inputs[i] * parameters[i];
74  }
75  return result;
76}
77
78float MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
79                                const vector<float>& parameters) const {
80  float result = 0.0;
81  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
82    //DCHECK_GE(inputs[i].first, 0);
83    //DCHECK_LT(inputs[i].first, parameters.size());
84    result += inputs[i].second * parameters[inputs[i].first];
85  }
86  return result;
87}
88
89float MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
90  float norm = 0;
91  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
92    norm += inputs[i] * inputs[i];
93  }
94  return norm;
95}
96
97float MulticlassPA::SparseL2NormSquare(
98    const vector<pair<int, float> >& inputs) const {
99  float norm = 0;
100  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
101    norm += inputs[i].second * inputs[i].second;
102  }
103  return norm;
104}
105
106float MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
107  //CHECK_GE(target, 0);
108  //CHECK_LT(target, num_classes_);
109  float target_class_score = Score(inputs, parameters_[target]);
110  //  VLOG(1) << "target class " << target << " score " << target_class_score;
111  int other_class = PickAClassExcept(target);
112  float other_class_score = Score(inputs, parameters_[other_class]);
113  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
114  float loss = 1.0 - target_class_score + other_class_score;
115  if (loss > 0.0) {
116    // Compute the learning rate according to PA-I.
117    float twice_norm_square = L2NormSquare(inputs) * 2.0;
118    if (twice_norm_square == 0.0) {
119      twice_norm_square = kEpsilon;
120    }
121    float rate = loss / twice_norm_square;
122    if (rate > aggressiveness_) {
123      rate = aggressiveness_;
124    }
125    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
126    // Modify the parameter vectors of the correct and wrong classes
127    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
128      // First modify the parameter value of the correct class
129      parameters_[target][i] += rate * inputs[i];
130      // Then modify the parameter value of the wrong class
131      parameters_[other_class][i] -= rate * inputs[i];
132    }
133    return loss;
134  }
135  return 0.0;
136}
137
138float MulticlassPA::SparseTrainOneExample(
139    const vector<pair<int, float> >& inputs, int target) {
140  // CHECK_GE(target, 0);
141  // CHECK_LT(target, num_classes_);
142  float target_class_score = SparseScore(inputs, parameters_[target]);
143  //  VLOG(1) << "target class " << target << " score " << target_class_score;
144  int other_class = PickAClassExcept(target);
145  float other_class_score = SparseScore(inputs, parameters_[other_class]);
146  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
147  float loss = 1.0 - target_class_score + other_class_score;
148  if (loss > 0.0) {
149    // Compute the learning rate according to PA-I.
150    float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
151    if (twice_norm_square == 0.0) {
152      twice_norm_square = kEpsilon;
153    }
154    float rate = loss / twice_norm_square;
155    if (rate > aggressiveness_) {
156      rate = aggressiveness_;
157    }
158    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
159    // Modify the parameter vectors of the correct and wrong classes
160    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
161      // First modify the parameter value of the correct class
162      parameters_[target][inputs[i].first] += rate * inputs[i].second;
163      // Then modify the parameter value of the wrong class
164      parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
165    }
166    return loss;
167  }
168  return 0.0;
169}
170
171float MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
172                          int num_iterations) {
173  int num_examples = data.size();
174  float total_loss = 0.0;
175  for (int t = 0; t < num_iterations; ++t) {
176    int index = PickAnExample(num_examples);
177    float loss_t = TrainOneExample(data[index].first, data[index].second);
178    total_loss += loss_t;
179  }
180  return total_loss / static_cast<float>(num_iterations);
181}
182
183float MulticlassPA::SparseTrain(
184    const vector<pair<vector<pair<int, float> >, int> >& data,
185    int num_iterations) {
186  int num_examples = data.size();
187  float total_loss = 0.0;
188  for (int t = 0; t < num_iterations; ++t) {
189    int index = PickAnExample(num_examples);
190    float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
191    total_loss += loss_t;
192  }
193  return total_loss / static_cast<float>(num_iterations);
194}
195
196int MulticlassPA::GetClass(const vector<float>& inputs) {
197  int best_class = -1;
198  float best_score = -10000.0;
199  // float best_score = -MathLimits<float>::kMax;
200  for (int i = 0; i < num_classes_; ++i) {
201    float score_i = Score(inputs, parameters_[i]);
202    if (score_i > best_score) {
203      best_score = score_i;
204      best_class = i;
205    }
206  }
207  return best_class;
208}
209
210int MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
211  int best_class = -1;
212  float best_score = -10000.0;
213  //float best_score = -MathLimits<float>::kMax;
214  for (int i = 0; i < num_classes_; ++i) {
215    float score_i = SparseScore(inputs, parameters_[i]);
216    if (score_i > best_score) {
217      best_score = score_i;
218      best_class = i;
219    }
220  }
221  return best_class;
222}
223
224float MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
225  int num_examples = data.size();
226  float total_error = 0.0;
227  for (int t = 0; t < num_examples; ++t) {
228    int best_class = GetClass(data[t].first);
229    if (best_class != data[t].second) {
230      ++total_error;
231    }
232  }
233  return total_error / num_examples;
234}
235
236float MulticlassPA::SparseTest(
237    const vector<pair<vector<pair<int, float> >, int> >& data) {
238  int num_examples = data.size();
239  float total_error = 0.0;
240  for (int t = 0; t < num_examples; ++t) {
241    int best_class = SparseGetClass(data[t].first);
242    if (best_class != data[t].second) {
243      ++total_error;
244    }
245  }
246  return total_error / num_examples;
247}
248}  // namespace learningfw
249