InstanceLearner.java revision 0a63716ed0e44f7cd32b81a444429318d42d8f08
/*
 * Copyright (C) 2008-2009 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.gesture;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;

/**
 * An implementation of an instance-based learner.
 *
 * <p>Each query vector is compared against every stored instance; every label
 * is scored by its closest instance (largest inverse distance).
 */
class InstanceLearner extends Learner {

    /**
     * Scores every label known to this learner against {@code vector}.
     *
     * <p>Instances whose vector length differs from {@code vector.length} are
     * skipped. For the remaining instances the distance is the cosine distance
     * when {@code sequenceType} equals {@link GestureStore#SEQUENCE_SENSITIVE},
     * and the squared Euclidean distance otherwise. An instance's weight is
     * {@code 1 / distance} ({@link Double#MAX_VALUE} for a distance of zero),
     * and a label's score is the largest weight among its instances. Scores
     * are then divided by their total so that they sum to one.
     *
     * @param sequenceType selects the distance function, see above
     * @param vector the vector to compare against each stored instance
     * @return one prediction per matched label, ordered by descending score;
     *         empty when no stored instance has a matching vector length
     */
    @Override
    ArrayList<Prediction> classify(int sequenceType, float[] vector) {
        ArrayList<Prediction> predictions = new ArrayList<Prediction>();
        ArrayList<Instance> instances = getInstances();
        int count = instances.size();
        TreeMap<String, Double> label2score = new TreeMap<String, Double>();
        for (int i = 0; i < count; i++) {
            Instance sample = instances.get(i);
            if (sample.vector.length != vector.length) {
                continue;
            }
            double distance;
            if (sequenceType == GestureStore.SEQUENCE_SENSITIVE) {
                distance = GestureUtilities.cosineDistance(sample.vector, vector);
            } else {
                distance = GestureUtilities.squaredEuclideanDistance(sample.vector, vector);
            }
            double weight;
            if (distance == 0) {
                weight = Double.MAX_VALUE;
            } else {
                // Clamp so a vanishingly small distance cannot yield an
                // infinite weight, which would normalize to NaN below.
                weight = Math.min(1 / distance, Double.MAX_VALUE);
            }
            Double score = label2score.get(sample.label);
            if (score == null || weight > score) {
                label2score.put(sample.label, weight);
            }
        }

        double sum = 0;
        double best = 0;
        for (Map.Entry<String, Double> entry : label2score.entrySet()) {
            double score = entry.getValue();
            sum += score;
            if (score > best) {
                best = score;
            }
            predictions.add(new Prediction(entry.getKey(), score));
        }

        // Two or more exact matches (weight Double.MAX_VALUE) overflow the sum
        // to +Infinity; dividing by it would flatten every score to zero.
        // Rescale by the largest score first so the total stays finite.
        if (sum == Double.POSITIVE_INFINITY) {
            sum = 0;
            for (Prediction prediction : predictions) {
                prediction.score /= best;
                sum += prediction.score;
            }
        }

        // normalize
        for (Prediction prediction : predictions) {
            prediction.score /= sum;
        }

        // Highest score first.
        Collections.sort(predictions, new Comparator<Prediction>() {
            public int compare(Prediction object1, Prediction object2) {
                return Double.compare(object2.score, object1.score);
            }
        });

        return predictions;
    }
}