/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// This file contains the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// passive-aggressive.
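//
// A minimal usage sketch (illustration only: the data and hyper-parameter
// values below are hypothetical, and the behavior of the constructor and of
// the float returned by Train() is defined in the accompanying .cpp):
//
//   learningfw::MulticlassPA classifier(/* num_classes = */ 3,
//                                       /* num_dimensions = */ 4,
//                                       /* aggressiveness = */ 0.1f);
//   classifier.InitializeParameters();
//   std::vector<std::pair<std::vector<float>, int> > training_data;
//   // ... fill training_data with (feature vector, label) pairs ...
//   classifier.Train(training_data, /* num_iterations = */ 100);
//   int predicted_class = classifier.GetClass(training_data[0].first);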

#ifndef LEARNINGFW_MULTICLASS_PA_H_
#define LEARNINGFW_MULTICLASS_PA_H_

#include <vector>
#include <cmath>

const float kEpsilon = 1.0e-4;

namespace learningfw {

class MulticlassPA {
 public:
  MulticlassPA(int num_classes,
               int num_dimensions,
               float aggressiveness);
  virtual ~MulticlassPA();

  // Initializes all parameters to 0.0.
  void InitializeParameters();

  // Returns a random class that is different from the target class.
  int PickAClassExcept(int target);

  // Returns the index of a random example.
  int PickAnExample(int num_examples);

  // Computes the score of a given input vector against a given parameter
  // vector as the dot product of the two.
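  //
  // Conceptually (a sketch of the stated dot product, not the exact code):
  //   Score(x, w)       = sum_i x[i] * w[i]
  //   SparseScore(x, w) = sum_k x[k].second * w[x[k].first]
  // assuming each sparse input element holds a (dimension index, value) pair.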
  float Score(const std::vector<float>& inputs,
              const std::vector<float>& parameters) const;
  float SparseScore(const std::vector<std::pair<int, float> >& inputs,
                    const std::vector<float>& parameters) const;

  // Returns the square of the L2 norm.
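  // (That is, sum_i inputs[i]^2; the sparse variant presumably sums the
  // squares of the stored values.)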
  float L2NormSquare(const std::vector<float>& inputs) const;
  float SparseL2NormSquare(
      const std::vector<std::pair<int, float> >& inputs) const;

  // Checks whether the given example is correctly classified, with a margin,
  // with respect to a random class.  If not, modifies the corresponding
  // parameters using the passive-aggressive update.
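  //
  // A sketch of the classic PA-I style update this corresponds to (an
  // assumption; the exact variant and constants are defined in the
  // accompanying .cpp):
  //   loss = max(0, 1 - (Score(x, w[target]) - Score(x, w[other])))
  //   tau  = min(aggressiveness, loss / (2 * ||x||^2))
  //   w[target] += tau * x;   w[other] -= tau * x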
  virtual float TrainOneExample(const std::vector<float>& inputs, int target);
  virtual float SparseTrainOneExample(
      const std::vector<std::pair<int, float> >& inputs, int target);

  // Iteratively trains the model on the given dataset for num_iterations
  // iterations.
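  // (Presumably each iteration picks a random example via PickAnExample() and
  // updates on it with TrainOneExample()/SparseTrainOneExample(); the exact
  // loop and the meaning of the returned float are defined in the .cpp.)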
  float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
              int num_iterations);
  float SparseTrain(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data,
      int num_iterations);

  // Returns the best-scoring class for a given input vector.
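  // (Conceptually, the class c that maximizes Score(inputs, parameters_[c]).)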
  virtual int GetClass(const std::vector<float>& inputs);
  virtual int SparseGetClass(const std::vector<std::pair<int, float> >& inputs);

  // Computes the test error of a given test set on the current model.
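  // (Presumably the fraction of examples whose predicted class differs from
  // the label; see the .cpp for the exact definition.)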
  float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
  float SparseTest(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data);

  // A few accessors used by the sub-classes.
  inline float aggressiveness() const {
    return aggressiveness_;
  }

  inline std::vector<std::vector<float> >& parameters() {
    return parameters_;
  }

  inline std::vector<std::vector<float> >* mutable_parameters() {
    return &parameters_;
  }

  inline int num_classes() const {
    return num_classes_;
  }

  inline int num_dimensions() const {
    return num_dimensions_;
  }

 private:
  // Keeps the current parameter vectors, one per class.
  std::vector<std::vector<float> > parameters_;

  // The number of classes of the problem.
  int num_classes_;

  // The number of dimensions of the input vectors.
  int num_dimensions_;

  // Controls how "aggressive" training should be.
  float aggressiveness_;
};
}  // namespace learningfw
#endif  // LEARNINGFW_MULTICLASS_PA_H_