multiclass_pa.cpp revision 6b4eebc73439cbc3ddfb547444a341d1f9be7996
16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua//
// This file contains the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of the
// Passive-Aggressive (PA) online learning algorithm.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
#include "native/multiclass_pa.h"

#include <cstdlib>
#include <limits>

using std::vector;
using std::pair;
266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learningfw {
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
// Converts one raw rand() draw into a float in [0, 1).
//
// Dividing by RAND_MAX alone maps rand() == RAND_MAX to exactly 1.0f, and
// because a float keeps only 24 significant bits, draws close to a large
// RAND_MAX also round up to 1.0f. Callers scale the result by a count and
// truncate it to an index, so 1.0f would yield an out-of-range index; such
// draws are therefore clamped to the largest float below 1.0f.
static float UnitIntervalFromRand(int raw) {
  const float value = static_cast<float>(raw) / RAND_MAX;
  if (value >= 1.0f) {
    // 1 - 2^-24 is exactly representable and is the largest float below 1.
    return 1.0f - std::numeric_limits<float>::epsilon() / 2.0f;
  }
  return value;
}

// Returns a pseudo-random float in [0, 1) taken from the C library rand()
// generator, so it shares rand()'s global state and any srand() seeding.
float RandFloat() {
  return UnitIntervalFromRand(rand());
}
326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::MulticlassPA(int num_classes,
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                           int num_dimensions,
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                           float aggressiveness)
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    : num_classes_(num_classes),
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      num_dimensions_(num_dimensions),
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      aggressiveness_(aggressiveness) {
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  InitializeParameters();
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::~MulticlassPA() {
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid MulticlassPA::InitializeParameters() {
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  parameters_.resize(num_classes_);
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    parameters_[i].resize(num_dimensions_);
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int j = 0; j < num_dimensions_; ++j) {
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[i][j] = 0.0;
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAClassExcept(int target) {
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int picked;
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  do {
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    picked = static_cast<int>(RandFloat() * num_classes_);
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    picked = static_cast<int>(random_.RandFloat() * num_classes_);
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } while (target == picked);
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return picked;
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAnExample(int num_examples) {
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return static_cast<int>(RandFloat() * num_examples);
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Score(const vector<float>& inputs,
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          const vector<float>& parameters) const {
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_EQ(inputs.size(), parameters.size());
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float result = 0.0;
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += inputs[i] * parameters[i];
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                const vector<float>& parameters) const {
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float result = 0.0;
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //DCHECK_GE(inputs[i].first, 0);
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //DCHECK_LT(inputs[i].first, parameters.size());
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += inputs[i].second * parameters[inputs[i].first];
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float norm = 0;
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm += inputs[i] * inputs[i];
936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return norm;
956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseL2NormSquare(
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<int, float> >& inputs) const {
996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float norm = 0;
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm += inputs[i].second * inputs[i].second;
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return norm;
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //CHECK_GE(target, 0);
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //CHECK_LT(target, num_classes_);
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float target_class_score = Score(inputs, parameters_[target]);
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "target class " << target << " score " << target_class_score;
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int other_class = PickAClassExcept(target);
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float other_class_score = Score(inputs, parameters_[other_class]);
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float loss = 1.0 - target_class_score + other_class_score;
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (loss > 0.0) {
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Compute the learning rate according to PA-I.
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float twice_norm_square = L2NormSquare(inputs) * 2.0;
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (twice_norm_square == 0.0) {
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      twice_norm_square = kEpsilon;
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float rate = loss / twice_norm_square;
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (rate > aggressiveness_) {
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      rate = aggressiveness_;
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Modify the parameter vectors of the correct and wrong classes
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // First modify the parameter value of the correct class
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[target][i] += rate * inputs[i];
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // Then modify the parameter value of the wrong class
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[other_class][i] -= rate * inputs[i];
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return loss;
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0.0;
1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrainOneExample(
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<int, float> >& inputs, int target) {
1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_GE(target, 0);
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_LT(target, num_classes_);
1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float target_class_score = SparseScore(inputs, parameters_[target]);
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "target class " << target << " score " << target_class_score;
1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int other_class = PickAClassExcept(target);
1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float other_class_score = SparseScore(inputs, parameters_[other_class]);
1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float loss = 1.0 - target_class_score + other_class_score;
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (loss > 0.0) {
1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Compute the learning rate according to PA-I.
1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (twice_norm_square == 0.0) {
1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      twice_norm_square = kEpsilon;
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float rate = loss / twice_norm_square;
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (rate > aggressiveness_) {
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      rate = aggressiveness_;
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Modify the parameter vectors of the correct and wrong classes
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // First modify the parameter value of the correct class
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[target][inputs[i].first] += rate * inputs[i].second;
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // Then modify the parameter value of the wrong class
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return loss;
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0.0;
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          int num_iterations) {
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_loss = 0.0;
1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_iterations; ++t) {
1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int index = PickAnExample(num_examples);
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float loss_t = TrainOneExample(data[index].first, data[index].second);
1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    total_loss += loss_t;
1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_loss / static_cast<float>(num_iterations);
1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrain(
1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<vector<pair<int, float> >, int> >& data,
1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int num_iterations) {
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_loss = 0.0;
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_iterations; ++t) {
1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int index = PickAnExample(num_examples);
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    total_loss += loss_t;
1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_loss / static_cast<float>(num_iterations);
1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::GetClass(const vector<float>& inputs) {
1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int best_class = -1;
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float best_score = -10000.0;
1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // float best_score = -MathLimits<float>::kMax;
2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float score_i = Score(inputs, parameters_[i]);
2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (score_i > best_score) {
2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_score = score_i;
2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_class = i;
2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return best_class;
2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int best_class = -1;
2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float best_score = -10000.0;
2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //float best_score = -MathLimits<float>::kMax;
2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float score_i = SparseScore(inputs, parameters_[i]);
2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (score_i > best_score) {
2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_score = score_i;
2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_class = i;
2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return best_class;
2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_error = 0.0;
2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_examples; ++t) {
2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int best_class = GetClass(data[t].first);
2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (best_class != data[t].second) {
2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++total_error;
2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_error / num_examples;
2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTest(
2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<vector<pair<int, float> >, int> >& data) {
2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_error = 0.0;
2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_examples; ++t) {
2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int best_class = SparseGetClass(data[t].first);
2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (best_class != data[t].second) {
2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++total_error;
2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_error / num_examples;
2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}  // namespace learningfw
249