/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// This file contains the MulticlassPA class which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// passive aggressive.
#ifndef LEARNINGFW_MULTICLASS_PA_H_
#define LEARNINGFW_MULTICLASS_PA_H_

#include <cmath>
#include <utility>  // std::pair, used throughout the sparse-input overloads.
#include <vector>

// Margin tolerance used by the passive-aggressive update rules.
const float kEpsilon = 1.0e-4;

namespace learningfw {

// A simple linear multi-class classifier based on the multi-prototype
// version of the passive-aggressive algorithm.  One weight vector is kept
// per class (parameters_[class][dimension]); prediction returns the class
// whose weight vector has the highest dot product with the input.  Dense
// inputs are std::vector<float> of size num_dimensions_; sparse inputs are
// vectors of (dimension index, value) pairs.
class MulticlassPA {
 public:
  MulticlassPA(int num_classes,
               int num_dimensions,
               float aggressiveness);
  virtual ~MulticlassPA();

  // Initialize all parameters to 0.0.
  void InitializeParameters();

  // Returns a random class that is different from the target class.
  int PickAClassExcept(int target);

  // Returns a random example index in [0, num_examples).
  int PickAnExample(int num_examples);

  // Computes the score of a given input vector for a given parameter
  // vector, by computing the dot product between the two.
  float Score(const std::vector<float>& inputs,
              const std::vector<float>& parameters) const;
  float SparseScore(const std::vector<std::pair<int, float> >& inputs,
                    const std::vector<float>& parameters) const;

  // Returns the square of the L2 norm of the input vector.
  float L2NormSquare(const std::vector<float>& inputs) const;
  float SparseL2NormSquare(
      const std::vector<std::pair<int, float> >& inputs) const;

  // Verify if the given example is correctly classified with margin with
  // respect to a random competing class.  If not, modifies the
  // corresponding parameter vectors using passive-aggressive.
  virtual float TrainOneExample(const std::vector<float>& inputs, int target);
  virtual float SparseTrainOneExample(
      const std::vector<std::pair<int, float> >& inputs, int target);

  // Iteratively train the model for num_iterations on the given dataset.
  // Each dataset element pairs an input vector with its target class.
  float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
              int num_iterations);
  float SparseTrain(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data,
      int num_iterations);

  // Returns the best (highest-scoring) class for a given input vector.
  virtual int GetClass(const std::vector<float>& inputs);
  virtual int SparseGetClass(const std::vector<std::pair<int, float> >& inputs);

  // Computes the test error of a given test set on the current model.
  float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
  float SparseTest(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data);

  // A few accessors used by the sub-classes.
  inline float aggressiveness() const {
    return aggressiveness_;
  }

  inline std::vector<std::vector<float> >& parameters() {
    return parameters_;
  }

  inline std::vector<std::vector<float> >* mutable_parameters() {
    // NOTE: the original text read "return ¶meters_;" — a mojibake of
    // "&para" rendered as the pilcrow entity.  Restored to the address-of
    // expression, which is the only well-formed reading.
    return &parameters_;
  }

  inline int num_classes() const {
    return num_classes_;
  }

  inline int num_dimensions() const {
    return num_dimensions_;
  }

 private:
  // Keeps the current parameter vectors, one per class.
  std::vector<std::vector<float> > parameters_;

  // The number of classes of the problem.
  int num_classes_;

  // The number of dimensions of the input vectors.
  int num_dimensions_;

  // Controls how "aggressive" training should be.
  float aggressiveness_;
};

}  // namespace learningfw
#endif  // LEARNINGFW_MULTICLASS_PA_H_