1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
// This file declares the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive algorithm.
20
21#ifndef LEARNINGFW_MULTICLASS_PA_H_
22#define LEARNINGFW_MULTICLASS_PA_H_
23
#include <cmath>
#include <utility>
#include <vector>
26
// Small positive constant (1e-4).  Declared at global scope with internal
// linkage, so every translation unit including this header gets its own
// copy.  How it is used is decided by the including code, not this header.
static const float kEpsilon = 1.0e-4f;
28
29namespace learningfw {
30
// Linear multi-class classifier trained with the multi-prototype
// passive-aggressive algorithm.
//
// Most operations come in two flavours: a dense one taking a
// std::vector<float>, and a "Sparse" one taking a vector of
// std::pair<int, float> entries (presumably (feature index, value) pairs --
// confirm against the implementation file).
//
// Only declarations live in this header.  Unless stated otherwise, the
// exact meaning of a returned float is defined by the implementation.
class MulticlassPA {
 public:
  // Builds a classifier for a problem with num_classes classes and input
  // vectors of num_dimensions dimensions.  aggressiveness controls how
  // "aggressive" training should be.
  MulticlassPA(int num_classes,
               int num_dimensions,
               float aggressiveness);
  virtual ~MulticlassPA();

  // Initializes all parameters to 0.0.
  void InitializeParameters();

  // Returns a random class that is different from the target class.
  int PickAClassExcept(int target);

  // Returns a random example, given the number of available examples.
  // NOTE(review): presumably an index in [0, num_examples) -- confirm.
  int PickAnExample(int num_examples);

  // Computes the score of a given input vector for a given parameter
  // vector, as the dot product between the two.
  float Score(const std::vector<float>& inputs,
              const std::vector<float>& parameters) const;
  float SparseScore(const std::vector<std::pair<int, float> >& inputs,
                    const std::vector<float>& parameters) const;

  // Returns the square of the L2 norm of the input vector.
  float L2NormSquare(const std::vector<float>& inputs) const;
  float SparseL2NormSquare(const std::vector<std::pair<int, float> >& inputs) const;

  // Checks whether the given example is classified correctly, with margin,
  // with respect to a random class.  If it is not, the corresponding
  // parameters are modified using the passive-aggressive update.
  // Virtual so that sub-classes can customize the update.
  virtual float TrainOneExample(const std::vector<float>& inputs, int target);
  virtual float SparseTrainOneExample(
      const std::vector<std::pair<int, float> >& inputs, int target);

  // Iteratively trains the model for num_iterations on the given dataset.
  // Each dataset entry pairs an input vector with its target class.
  float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
              int num_iterations);
  float SparseTrain(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data,
      int num_iterations);

  // Returns the best class for a given input vector.
  virtual int GetClass(const std::vector<float>& inputs);
  virtual int SparseGetClass(const std::vector<std::pair<int, float> >& inputs);

  // Computes the test error of a given test set on the current model.
  float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
  float SparseTest(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data);

  // A few accessors used by the sub-classes.

  // Returns the aggressiveness value.
  float aggressiveness() const {
    return aggressiveness_;
  }

  // Returns a mutable reference to the parameters.
  std::vector<std::vector<float> >& parameters() {
    return parameters_;
  }

  // Read-only view of the parameters, so that they can be inspected
  // through a const MulticlassPA reference or pointer.
  const std::vector<std::vector<float> >& parameters() const {
    return parameters_;
  }

  // Returns a pointer to the parameters, for callers that modify them.
  std::vector<std::vector<float> >* mutable_parameters() {
    return &parameters_;
  }

  // Returns the number of classes of the problem.
  int num_classes() const {
    return num_classes_;
  }

  // Returns the number of dimensions of the input vectors.
  int num_dimensions() const {
    return num_dimensions_;
  }

 private:
  // Keeps the current parameter vector.
  // NOTE(review): presumably one inner vector per class -- confirm against
  // the implementation file.
  std::vector<std::vector<float> > parameters_;

  // The number of classes of the problem.
  int num_classes_;

  // The number of dimensions of the input vectors.
  int num_dimensions_;

  // Controls how "aggressive" training should be.
  float aggressiveness_;
};
116}  // namespace learningfw
117#endif  // LEARNINGFW_MULTICLASS_PA_H_
118