16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua//
186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// This file contains the MulticlassPA class which implements a simple
196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// linear multi-class classifier based on the multi-prototype version of
206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// passive aggressive.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
#include "native/multiclass_pa.h"

#include <stdlib.h>

#include <limits>

using std::vector;
using std::pair;
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learningfw {
306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat RandFloat() {
326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return static_cast<float>(rand()) / RAND_MAX;
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::MulticlassPA(int num_classes,
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                           int num_dimensions,
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                           float aggressiveness)
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    : num_classes_(num_classes),
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      num_dimensions_(num_dimensions),
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      aggressiveness_(aggressiveness) {
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  InitializeParameters();
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei HuaMulticlassPA::~MulticlassPA() {
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid MulticlassPA::InitializeParameters() {
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  parameters_.resize(num_classes_);
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    parameters_[i].resize(num_dimensions_);
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int j = 0; j < num_dimensions_; ++j) {
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[i][j] = 0.0;
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAClassExcept(int target) {
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int picked;
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  do {
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    picked = static_cast<int>(RandFloat() * num_classes_);
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    picked = static_cast<int>(random_.RandFloat() * num_classes_);
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } while (target == picked);
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return picked;
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::PickAnExample(int num_examples) {
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return static_cast<int>(RandFloat() * num_examples);
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Score(const vector<float>& inputs,
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          const vector<float>& parameters) const {
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_EQ(inputs.size(), parameters.size());
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float result = 0.0;
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += inputs[i] * parameters[i];
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                const vector<float>& parameters) const {
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float result = 0.0;
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //DCHECK_GE(inputs[i].first, 0);
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //DCHECK_LT(inputs[i].first, parameters.size());
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += inputs[i].second * parameters[inputs[i].first];
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float norm = 0;
936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm += inputs[i] * inputs[i];
956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return norm;
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseL2NormSquare(
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<int, float> >& inputs) const {
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float norm = 0;
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    norm += inputs[i].second * inputs[i].second;
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return norm;
1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //CHECK_GE(target, 0);
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //CHECK_LT(target, num_classes_);
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float target_class_score = Score(inputs, parameters_[target]);
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "target class " << target << " score " << target_class_score;
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int other_class = PickAClassExcept(target);
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float other_class_score = Score(inputs, parameters_[other_class]);
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float loss = 1.0 - target_class_score + other_class_score;
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (loss > 0.0) {
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Compute the learning rate according to PA-I.
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float twice_norm_square = L2NormSquare(inputs) * 2.0;
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (twice_norm_square == 0.0) {
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      twice_norm_square = kEpsilon;
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float rate = loss / twice_norm_square;
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (rate > aggressiveness_) {
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      rate = aggressiveness_;
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Modify the parameter vectors of the correct and wrong classes
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // First modify the parameter value of the correct class
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[target][i] += rate * inputs[i];
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // Then modify the parameter value of the wrong class
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[other_class][i] -= rate * inputs[i];
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return loss;
1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0.0;
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrainOneExample(
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<int, float> >& inputs, int target) {
1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_GE(target, 0);
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // CHECK_LT(target, num_classes_);
1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float target_class_score = SparseScore(inputs, parameters_[target]);
1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "target class " << target << " score " << target_class_score;
1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int other_class = PickAClassExcept(target);
1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float other_class_score = SparseScore(inputs, parameters_[other_class]);
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float loss = 1.0 - target_class_score + other_class_score;
1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (loss > 0.0) {
1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Compute the learning rate according to PA-I.
1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (twice_norm_square == 0.0) {
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      twice_norm_square = kEpsilon;
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float rate = loss / twice_norm_square;
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (rate > aggressiveness_) {
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      rate = aggressiveness_;
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //    VLOG(1) << "loss = " << loss << " rate = " << rate;
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // Modify the parameter vectors of the correct and wrong classes
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // First modify the parameter value of the correct class
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[target][inputs[i].first] += rate * inputs[i].second;
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      // Then modify the parameter value of the wrong class
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return loss;
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0.0;
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          int num_iterations) {
1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_loss = 0.0;
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_iterations; ++t) {
1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int index = PickAnExample(num_examples);
1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float loss_t = TrainOneExample(data[index].first, data[index].second);
1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    total_loss += loss_t;
1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_loss / static_cast<float>(num_iterations);
1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTrain(
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<vector<pair<int, float> >, int> >& data,
1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int num_iterations) {
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_loss = 0.0;
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_iterations; ++t) {
1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int index = PickAnExample(num_examples);
1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    total_loss += loss_t;
1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_loss / static_cast<float>(num_iterations);
1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::GetClass(const vector<float>& inputs) {
1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int best_class = -1;
2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float best_score = -10000.0;
2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // float best_score = -MathLimits<float>::kMax;
2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float score_i = Score(inputs, parameters_[i]);
2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (score_i > best_score) {
2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_score = score_i;
2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_class = i;
2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return best_class;
2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int best_class = -1;
2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float best_score = -10000.0;
2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  //float best_score = -MathLimits<float>::kMax;
2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int i = 0; i < num_classes_; ++i) {
2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    float score_i = SparseScore(inputs, parameters_[i]);
2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (score_i > best_score) {
2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_score = score_i;
2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      best_class = i;
2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return best_class;
2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_error = 0.0;
2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_examples; ++t) {
2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int best_class = GetClass(data[t].first);
2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (best_class != data[t].second) {
2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++total_error;
2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_error / num_examples;
2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huafloat MulticlassPA::SparseTest(
2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const vector<pair<vector<pair<int, float> >, int> >& data) {
2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int num_examples = data.size();
2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  float total_error = 0.0;
2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (int t = 0; t < num_examples; ++t) {
2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int best_class = SparseGetClass(data[t].first);
2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (best_class != data[t].second) {
2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++total_error;
2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return total_error / num_examples;
2496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}  // namespace learningfw
251