/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// This file contains the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive (PA) algorithm.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "native/multiclass_pa.h" 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 24a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <stdlib.h> 25a08525ea290ff4edc766eda1ec80388be866a79eDan Albert 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::vector; 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::pair; 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learningfw { 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat RandFloat() { 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return static_cast<float>(rand()) / RAND_MAX; 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::MulticlassPA(int num_classes, 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_dimensions, 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float aggressiveness) 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua : num_classes_(num_classes), 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua num_dimensions_(num_dimensions), 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua aggressiveness_(aggressiveness) { 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua InitializeParameters(); 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::~MulticlassPA() { 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid MulticlassPA::InitializeParameters() { 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_.resize(num_classes_); 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[i].resize(num_dimensions_); 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int j = 0; j < num_dimensions_; ++j) { 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[i][j] = 0.0; 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAClassExcept(int target) { 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int picked; 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua do { 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua picked = static_cast<int>(RandFloat() * num_classes_); 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // picked = static_cast<int>(random_.RandFloat() * num_classes_); 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } while (target == picked); 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return picked; 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAnExample(int num_examples) { 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return static_cast<int>(RandFloat() * num_examples); 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Score(const vector<float>& inputs, 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<float>& parameters) const { 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_EQ(inputs.size(), parameters.size()); 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float result = 0.0; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += inputs[i] * 
parameters[i]; 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs, 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<float>& parameters) const { 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float result = 0.0; 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //DCHECK_GE(inputs[i].first, 0); 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //DCHECK_LT(inputs[i].first, parameters.size()); 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += inputs[i].second * parameters[inputs[i].first]; 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::L2NormSquare(const vector<float>& inputs) const { 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float norm = 0; 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm += inputs[i] * inputs[i]; 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm; 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseL2NormSquare( 1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<int, float> >& inputs) const { 1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float norm = 0; 
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua norm += inputs[i].second * inputs[i].second; 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return norm; 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) { 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //CHECK_GE(target, 0); 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //CHECK_LT(target, num_classes_); 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float target_class_score = Score(inputs, parameters_[target]); 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "target class " << target << " score " << target_class_score; 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int other_class = PickAClassExcept(target); 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float other_class_score = Score(inputs, parameters_[other_class]); 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "other class " << other_class << " score " << other_class_score; 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss = 1.0 - target_class_score + other_class_score; 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (loss > 0.0) { 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute the learning rate according to PA-I. 
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float twice_norm_square = L2NormSquare(inputs) * 2.0; 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (twice_norm_square == 0.0) { 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua twice_norm_square = kEpsilon; 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float rate = loss / twice_norm_square; 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (rate > aggressiveness_) { 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rate = aggressiveness_; 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "loss = " << loss << " rate = " << rate; 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Modify the parameter vectors of the correct and wrong classes 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // First modify the parameter value of the correct class 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[target][i] += rate * inputs[i]; 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Then modify the parameter value of the wrong class 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[other_class][i] -= rate * inputs[i]; 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return loss; 1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0.0; 1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrainOneExample( 1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<int, float> >& inputs, int target) { 1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_GE(target, 0); 
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // CHECK_LT(target, num_classes_); 1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float target_class_score = SparseScore(inputs, parameters_[target]); 1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "target class " << target << " score " << target_class_score; 1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int other_class = PickAClassExcept(target); 1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float other_class_score = SparseScore(inputs, parameters_[other_class]); 1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "other class " << other_class << " score " << other_class_score; 1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss = 1.0 - target_class_score + other_class_score; 1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (loss > 0.0) { 1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute the learning rate according to PA-I. 1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float twice_norm_square = SparseL2NormSquare(inputs) * 2.0; 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (twice_norm_square == 0.0) { 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua twice_norm_square = kEpsilon; 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float rate = loss / twice_norm_square; 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (rate > aggressiveness_) { 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua rate = aggressiveness_; 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // VLOG(1) << "loss = " << loss << " rate = " << rate; 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Modify the parameter vectors of the correct and wrong classes 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // First modify the 
parameter value of the correct class 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[target][inputs[i].first] += rate * inputs[i].second; 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Then modify the parameter value of the wrong class 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua parameters_[other_class][inputs[i].first] -= rate * inputs[i].second; 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return loss; 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0.0; 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Train(const vector<pair<vector<float>, int> >& data, 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_iterations) { 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_loss = 0.0; 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_iterations; ++t) { 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int index = PickAnExample(num_examples); 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss_t = TrainOneExample(data[index].first, data[index].second); 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua total_loss += loss_t; 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_loss / static_cast<float>(num_iterations); 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrain( 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const vector<pair<vector<pair<int, float> >, int> >& data, 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_iterations) { 
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_loss = 0.0; 1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_iterations; ++t) { 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int index = PickAnExample(num_examples); 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float loss_t = SparseTrainOneExample(data[index].first, data[index].second); 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua total_loss += loss_t; 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_loss / static_cast<float>(num_iterations); 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::GetClass(const vector<float>& inputs) { 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = -1; 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float best_score = -10000.0; 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // float best_score = -MathLimits<float>::kMax; 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float score_i = Score(inputs, parameters_[i]); 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (score_i > best_score) { 2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_score = score_i; 2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_class = i; 2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return best_class; 2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) { 
2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = -1; 2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float best_score = -10000.0; 2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //float best_score = -MathLimits<float>::kMax; 2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int i = 0; i < num_classes_; ++i) { 2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float score_i = SparseScore(inputs, parameters_[i]); 2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (score_i > best_score) { 2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_score = score_i; 2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua best_class = i; 2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return best_class; 2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) { 2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_error = 0.0; 2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_examples; ++t) { 2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = GetClass(data[t].first); 2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (best_class != data[t].second) { 2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++total_error; 2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_error / num_examples; 2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTest( 2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const 
vector<pair<vector<pair<int, float> >, int> >& data) { 2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int num_examples = data.size(); 2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float total_error = 0.0; 2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (int t = 0; t < num_examples; ++t) { 2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int best_class = SparseGetClass(data[t].first); 2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (best_class != data[t].second) { 2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++total_error; 2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return total_error / num_examples; 2496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learningfw 251