1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//
18// This file contains the MulticlassPA class which implements a simple
19// linear multi-class classifier based on the multi-prototype version of
20// passive aggressive.
21
#include "native/multiclass_pa.h"

#include <stdlib.h>

#include <algorithm>
#include <limits>
25
26using std::vector;
27using std::pair;
28
29namespace learningfw {
30
// Returns a pseudo-random float derived from rand(), in the half-open
// interval [0, 1).
//
// The upper bound is exclusive on purpose: the index pickers in this file turn
// the result into an index with static_cast<int>(RandFloat() * n), so a value
// of exactly 1.0 would produce the out-of-range index n.
float RandFloat() {
  // Divide in double by RAND_MAX + 1 so that rand() == RAND_MAX stays below 1;
  // doing the "+ 1" in int would overflow where RAND_MAX == INT_MAX.
  const float value = static_cast<float>(
      static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0));
  // Narrowing to float can still round a quotient just below 1.0 up to exactly
  // 1.0f (e.g. when RAND_MAX is 2^31 - 1), so clamp to the largest float
  // strictly below one.
  const float kLargestBelowOne =
      1.0f - std::numeric_limits<float>::epsilon() / 2.0f;
  return value < 1.0f ? value : kLargestBelowOne;
}
34
35MulticlassPA::MulticlassPA(int num_classes,
36                           int num_dimensions,
37                           float aggressiveness)
38    : num_classes_(num_classes),
39      num_dimensions_(num_dimensions),
40      aggressiveness_(aggressiveness) {
41  InitializeParameters();
42}
43
44MulticlassPA::~MulticlassPA() {
45}
46
47void MulticlassPA::InitializeParameters() {
48  parameters_.resize(num_classes_);
49  for (int i = 0; i < num_classes_; ++i) {
50    parameters_[i].resize(num_dimensions_);
51    for (int j = 0; j < num_dimensions_; ++j) {
52      parameters_[i][j] = 0.0;
53    }
54  }
55}
56
57int MulticlassPA::PickAClassExcept(int target) {
58  int picked;
59  do {
60    picked = static_cast<int>(RandFloat() * num_classes_);
61    //    picked = static_cast<int>(random_.RandFloat() * num_classes_);
62  } while (target == picked);
63  return picked;
64}
65
66int MulticlassPA::PickAnExample(int num_examples) {
67  return static_cast<int>(RandFloat() * num_examples);
68}
69
70float MulticlassPA::Score(const vector<float>& inputs,
71                          const vector<float>& parameters) const {
72  // CHECK_EQ(inputs.size(), parameters.size());
73  float result = 0.0;
74  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
75    result += inputs[i] * parameters[i];
76  }
77  return result;
78}
79
80float MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
81                                const vector<float>& parameters) const {
82  float result = 0.0;
83  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
84    //DCHECK_GE(inputs[i].first, 0);
85    //DCHECK_LT(inputs[i].first, parameters.size());
86    result += inputs[i].second * parameters[inputs[i].first];
87  }
88  return result;
89}
90
91float MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
92  float norm = 0;
93  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
94    norm += inputs[i] * inputs[i];
95  }
96  return norm;
97}
98
99float MulticlassPA::SparseL2NormSquare(
100    const vector<pair<int, float> >& inputs) const {
101  float norm = 0;
102  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
103    norm += inputs[i].second * inputs[i].second;
104  }
105  return norm;
106}
107
108float MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
109  //CHECK_GE(target, 0);
110  //CHECK_LT(target, num_classes_);
111  float target_class_score = Score(inputs, parameters_[target]);
112  //  VLOG(1) << "target class " << target << " score " << target_class_score;
113  int other_class = PickAClassExcept(target);
114  float other_class_score = Score(inputs, parameters_[other_class]);
115  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
116  float loss = 1.0 - target_class_score + other_class_score;
117  if (loss > 0.0) {
118    // Compute the learning rate according to PA-I.
119    float twice_norm_square = L2NormSquare(inputs) * 2.0;
120    if (twice_norm_square == 0.0) {
121      twice_norm_square = kEpsilon;
122    }
123    float rate = loss / twice_norm_square;
124    if (rate > aggressiveness_) {
125      rate = aggressiveness_;
126    }
127    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
128    // Modify the parameter vectors of the correct and wrong classes
129    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
130      // First modify the parameter value of the correct class
131      parameters_[target][i] += rate * inputs[i];
132      // Then modify the parameter value of the wrong class
133      parameters_[other_class][i] -= rate * inputs[i];
134    }
135    return loss;
136  }
137  return 0.0;
138}
139
140float MulticlassPA::SparseTrainOneExample(
141    const vector<pair<int, float> >& inputs, int target) {
142  // CHECK_GE(target, 0);
143  // CHECK_LT(target, num_classes_);
144  float target_class_score = SparseScore(inputs, parameters_[target]);
145  //  VLOG(1) << "target class " << target << " score " << target_class_score;
146  int other_class = PickAClassExcept(target);
147  float other_class_score = SparseScore(inputs, parameters_[other_class]);
148  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
149  float loss = 1.0 - target_class_score + other_class_score;
150  if (loss > 0.0) {
151    // Compute the learning rate according to PA-I.
152    float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
153    if (twice_norm_square == 0.0) {
154      twice_norm_square = kEpsilon;
155    }
156    float rate = loss / twice_norm_square;
157    if (rate > aggressiveness_) {
158      rate = aggressiveness_;
159    }
160    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
161    // Modify the parameter vectors of the correct and wrong classes
162    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
163      // First modify the parameter value of the correct class
164      parameters_[target][inputs[i].first] += rate * inputs[i].second;
165      // Then modify the parameter value of the wrong class
166      parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
167    }
168    return loss;
169  }
170  return 0.0;
171}
172
173float MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
174                          int num_iterations) {
175  int num_examples = data.size();
176  float total_loss = 0.0;
177  for (int t = 0; t < num_iterations; ++t) {
178    int index = PickAnExample(num_examples);
179    float loss_t = TrainOneExample(data[index].first, data[index].second);
180    total_loss += loss_t;
181  }
182  return total_loss / static_cast<float>(num_iterations);
183}
184
185float MulticlassPA::SparseTrain(
186    const vector<pair<vector<pair<int, float> >, int> >& data,
187    int num_iterations) {
188  int num_examples = data.size();
189  float total_loss = 0.0;
190  for (int t = 0; t < num_iterations; ++t) {
191    int index = PickAnExample(num_examples);
192    float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
193    total_loss += loss_t;
194  }
195  return total_loss / static_cast<float>(num_iterations);
196}
197
198int MulticlassPA::GetClass(const vector<float>& inputs) {
199  int best_class = -1;
200  float best_score = -10000.0;
201  // float best_score = -MathLimits<float>::kMax;
202  for (int i = 0; i < num_classes_; ++i) {
203    float score_i = Score(inputs, parameters_[i]);
204    if (score_i > best_score) {
205      best_score = score_i;
206      best_class = i;
207    }
208  }
209  return best_class;
210}
211
212int MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
213  int best_class = -1;
214  float best_score = -10000.0;
215  //float best_score = -MathLimits<float>::kMax;
216  for (int i = 0; i < num_classes_; ++i) {
217    float score_i = SparseScore(inputs, parameters_[i]);
218    if (score_i > best_score) {
219      best_score = score_i;
220      best_class = i;
221    }
222  }
223  return best_class;
224}
225
226float MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
227  int num_examples = data.size();
228  float total_error = 0.0;
229  for (int t = 0; t < num_examples; ++t) {
230    int best_class = GetClass(data[t].first);
231    if (best_class != data[t].second) {
232      ++total_error;
233    }
234  }
235  return total_error / num_examples;
236}
237
238float MulticlassPA::SparseTest(
239    const vector<pair<vector<pair<int, float> >, int> >& data) {
240  int num_examples = data.size();
241  float total_error = 0.0;
242  for (int t = 0; t < num_examples; ++t) {
243    int best_class = SparseGetClass(data[t].first);
244    if (best_class != data[t].second) {
245      ++total_error;
246    }
247  }
248  return total_error / num_examples;
249}
250}  // namespace learningfw
251