/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// This file contains the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive (PA) online learning algorithm.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "native/multiclass_pa.h" 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::vector; 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::pair; 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learningfw { 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat RandFloat() { 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return static_cast<float>(rand()) / RAND_MAX; 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::MulticlassPA(int num_classes, 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_dimensions, 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float aggressiveness) 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua : num_classes_(num_classes), 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua num_dimensions_(num_dimensions), 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua aggressiveness_(aggressiveness) { 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua InitializeParameters(); 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::~MulticlassPA() { 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid MulticlassPA::InitializeParameters() { 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_.resize(num_classes_); 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[i].resize(num_dimensions_); 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Hua for (int j = 0; j < num_dimensions_; ++j) { 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[i][j] = 0.0; 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAClassExcept(int target) { 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int picked; 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua do { 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua picked = static_cast<int>(RandFloat() * num_classes_); 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // picked = static_cast<int>(random_.RandFloat() * num_classes_); 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } while (target == picked); 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return picked; 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAnExample(int num_examples) { 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return static_cast<int>(RandFloat() * num_examples); 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Score(const vector<float>& inputs, 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<float>& parameters) const { 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_EQ(inputs.size(), parameters.size()); 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float result = 0.0; 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += inputs[i] * parameters[i]; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs, 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<float>& parameters) const { 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float result = 0.0; 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //DCHECK_GE(inputs[i].first, 0); 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //DCHECK_LT(inputs[i].first, parameters.size()); 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += inputs[i].second * parameters[inputs[i].first]; 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::L2NormSquare(const vector<float>& inputs) const { 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float norm = 0; 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm += inputs[i] * inputs[i]; 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm; 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseL2NormSquare( 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<int, float> >& inputs) const { 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float norm = 0; 1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm += inputs[i].second * inputs[i].second; 1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm; 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) { 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //CHECK_GE(target, 0); 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //CHECK_LT(target, num_classes_); 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float target_class_score = Score(inputs, parameters_[target]); 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "target class " << target << " score " << target_class_score; 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int other_class = PickAClassExcept(target); 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float other_class_score = Score(inputs, parameters_[other_class]); 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "other class " << other_class << " score " << other_class_score; 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss = 1.0 - target_class_score + other_class_score; 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (loss > 0.0) { 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute the learning rate according to PA-I. 
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float twice_norm_square = L2NormSquare(inputs) * 2.0; 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (twice_norm_square == 0.0) { 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua twice_norm_square = kEpsilon; 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float rate = loss / twice_norm_square; 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (rate > aggressiveness_) { 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rate = aggressiveness_; 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "loss = " << loss << " rate = " << rate; 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Modify the parameter vectors of the correct and wrong classes 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // First modify the parameter value of the correct class 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[target][i] += rate * inputs[i]; 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Then modify the parameter value of the wrong class 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[other_class][i] -= rate * inputs[i]; 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return loss; 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0.0; 1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrainOneExample( 1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<int, float> >& inputs, int target) { 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_GE(target, 0); 
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_LT(target, num_classes_); 1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float target_class_score = SparseScore(inputs, parameters_[target]); 1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "target class " << target << " score " << target_class_score; 1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int other_class = PickAClassExcept(target); 1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float other_class_score = SparseScore(inputs, parameters_[other_class]); 1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "other class " << other_class << " score " << other_class_score; 1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss = 1.0 - target_class_score + other_class_score; 1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (loss > 0.0) { 1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute the learning rate according to PA-I. 1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float twice_norm_square = SparseL2NormSquare(inputs) * 2.0; 1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (twice_norm_square == 0.0) { 1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua twice_norm_square = kEpsilon; 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float rate = loss / twice_norm_square; 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (rate > aggressiveness_) { 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rate = aggressiveness_; 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "loss = " << loss << " rate = " << rate; 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Modify the parameter vectors of the correct and wrong classes 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // First modify the 
parameter value of the correct class 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[target][inputs[i].first] += rate * inputs[i].second; 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Then modify the parameter value of the wrong class 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[other_class][inputs[i].first] -= rate * inputs[i].second; 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return loss; 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0.0; 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Train(const vector<pair<vector<float>, int> >& data, 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_iterations) { 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_loss = 0.0; 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_iterations; ++t) { 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int index = PickAnExample(num_examples); 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss_t = TrainOneExample(data[index].first, data[index].second); 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua total_loss += loss_t; 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_loss / static_cast<float>(num_iterations); 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrain( 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<vector<pair<int, float> >, int> >& data, 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_iterations) { 
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_loss = 0.0; 1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_iterations; ++t) { 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int index = PickAnExample(num_examples); 1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss_t = SparseTrainOneExample(data[index].first, data[index].second); 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua total_loss += loss_t; 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_loss / static_cast<float>(num_iterations); 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::GetClass(const vector<float>& inputs) { 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = -1; 1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float best_score = -10000.0; 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // float best_score = -MathLimits<float>::kMax; 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float score_i = Score(inputs, parameters_[i]); 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (score_i > best_score) { 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_score = score_i; 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_class = i; 2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return best_class; 2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) { 
2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = -1; 2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float best_score = -10000.0; 2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //float best_score = -MathLimits<float>::kMax; 2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float score_i = SparseScore(inputs, parameters_[i]); 2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (score_i > best_score) { 2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_score = score_i; 2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_class = i; 2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return best_class; 2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) { 2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_error = 0.0; 2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_examples; ++t) { 2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = GetClass(data[t].first); 2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (best_class != data[t].second) { 2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++total_error; 2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_error / num_examples; 2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTest( 2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const 
vector<pair<vector<pair<int, float> >, int> >& data) { 2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_error = 0.0; 2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_examples; ++t) { 2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = SparseGetClass(data[t].first); 2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (best_class != data[t].second) { 2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++total_error; 2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_error / num_examples; 2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learningfw 249