16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/* 26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2011 The Android Open Source Project 36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License"); 56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License. 66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at 76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * http://www.apache.org/licenses/LICENSE-2.0 96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * 106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software 116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS, 126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and 146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License. 156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huapackage android.bordeaux.learning; 191dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua 206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport android.util.Log; 211dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua 221dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.io.Serializable; 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.List; 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.Arrays; 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.ArrayList; 26b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.HashMap; 27b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.Map; 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/** 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * can be used to compare samples. 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * This java class wraps the native StochasticLinearRanker class. 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * To update the ranker, call updateClassifier with two samples, with the first 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * one having higher rank than the second one. 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * To get the rank score of the sample call scoreSample. 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * TODO: adding more interfaces for changing the learning parameters 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huapublic class StochasticLinearRanker { 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua String TAG = "StochasticLinearRanker"; 40b019e89cbea221598c482b05ab68b7660b41aa23saberian public static int VAR_NUM = 14; 411dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua static public class Model implements Serializable { 42b019e89cbea221598c482b05ab68b7660b41aa23saberian public HashMap<String, Float> weights = new HashMap<String, Float>(); 43b019e89cbea221598c482b05ab68b7660b41aa23saberian public float weightNormalizer = 1; 44b019e89cbea221598c482b05ab68b7660b41aa23saberian public HashMap<String, String> parameters = new HashMap<String, String>(); 451dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua } 461dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua 47b019e89cbea221598c482b05ab68b7660b41aa23saberian /** 48b019e89cbea221598c482b05ab68b7660b41aa23saberian * Initializing a ranker 49b019e89cbea221598c482b05ab68b7660b41aa23saberian */ 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public StochasticLinearRanker() { 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mNativeClassifier = initNativeClassifier(); 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua /** 55b019e89cbea221598c482b05ab68b7660b41aa23saberian * Reset the ranker 56b019e89cbea221598c482b05ab68b7660b41aa23saberian */ 57b019e89cbea221598c482b05ab68b7660b41aa23saberian public void resetRanker(){ 58b019e89cbea221598c482b05ab68b7660b41aa23saberian deleteNativeClassifier(mNativeClassifier); 59b019e89cbea221598c482b05ab68b7660b41aa23saberian mNativeClassifier = initNativeClassifier(); 60b019e89cbea221598c482b05ab68b7660b41aa23saberian } 61b019e89cbea221598c482b05ab68b7660b41aa23saberian 62b019e89cbea221598c482b05ab68b7660b41aa23saberian /** 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Train the ranker with a pair of samples. A sample, a pair of arrays of 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * keys and values. The first sample should have higher rank than the second 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * one. 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public boolean updateClassifier(String[] keys_positive, 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float[] values_positive, 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua String[] keys_negative, 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float[] values_negative) { 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return nativeUpdateClassifier(keys_positive, values_positive, 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua keys_negative, values_negative, 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua mNativeClassifier); 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua /** 771dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua * Get the rank score of the sample, a sample is a list of key, value pairs. 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public float scoreSample(String[] keys, float[] values) { 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return nativeScoreSample(keys, values, mNativeClassifier); 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua /** 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Get the current model and parameters of ranker 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 86b019e89cbea221598c482b05ab68b7660b41aa23saberian public Model getUModel(){ 87b019e89cbea221598c482b05ab68b7660b41aa23saberian Model slrModel = new Model(); 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int len = nativeGetLengthClassifier(mNativeClassifier); 89b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] wKeys = new String[len]; 90b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] wValues = new float[len]; 91b019e89cbea221598c482b05ab68b7660b41aa23saberian float wNormalizer = 1; 92b019e89cbea221598c482b05ab68b7660b41aa23saberian nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier); 93b019e89cbea221598c482b05ab68b7660b41aa23saberian slrModel.weightNormalizer = wNormalizer; 94b019e89cbea221598c482b05ab68b7660b41aa23saberian for (int i=0; i< wKeys.length ; i++) 95b019e89cbea221598c482b05ab68b7660b41aa23saberian slrModel.weights.put(wKeys[i], wValues[i]); 96b019e89cbea221598c482b05ab68b7660b41aa23saberian 97b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] paramKeys = new String[VAR_NUM]; 98b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] paramValues = new String[VAR_NUM]; 99b019e89cbea221598c482b05ab68b7660b41aa23saberian nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier); 100b019e89cbea221598c482b05ab68b7660b41aa23saberian for (int i=0; i< paramKeys.length ; i++) 101b019e89cbea221598c482b05ab68b7660b41aa23saberian slrModel.parameters.put(paramKeys[i], paramValues[i]); 102b019e89cbea221598c482b05ab68b7660b41aa23saberian return slrModel; 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua /** 106b019e89cbea221598c482b05ab68b7660b41aa23saberian * load the given model and parameters to the ranker 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 1081dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua public boolean loadModel(Model model) { 109b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] wKeys = new String[model.weights.size()]; 110b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] wValues = new float[model.weights.size()]; 111b019e89cbea221598c482b05ab68b7660b41aa23saberian int i = 0 ; 112b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, Float> e : model.weights.entrySet()){ 113b019e89cbea221598c482b05ab68b7660b41aa23saberian wKeys[i] = e.getKey(); 114b019e89cbea221598c482b05ab68b7660b41aa23saberian wValues[i] = e.getValue(); 115b019e89cbea221598c482b05ab68b7660b41aa23saberian i++; 1161dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua } 117b019e89cbea221598c482b05ab68b7660b41aa23saberian boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer); 118b019e89cbea221598c482b05ab68b7660b41aa23saberian if (!res) 119b019e89cbea221598c482b05ab68b7660b41aa23saberian return false; 120b019e89cbea221598c482b05ab68b7660b41aa23saberian 121b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, String> e : model.parameters.entrySet()){ 122b019e89cbea221598c482b05ab68b7660b41aa23saberian res = setModelParameter(e.getKey(), e.getValue()); 123b019e89cbea221598c482b05ab68b7660b41aa23saberian if (!res) 124b019e89cbea221598c482b05ab68b7660b41aa23saberian return false; 1251dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua } 126b019e89cbea221598c482b05ab68b7660b41aa23saberian return res; 127b019e89cbea221598c482b05ab68b7660b41aa23saberian } 128b019e89cbea221598c482b05ab68b7660b41aa23saberian 129b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean setModelWeights(String[] keys, float [] values, float normalizer){ 130b019e89cbea221598c482b05ab68b7660b41aa23saberian return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier); 131b019e89cbea221598c482b05ab68b7660b41aa23saberian } 132b019e89cbea221598c482b05ab68b7660b41aa23saberian 133b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean setModelParameter(String key, String value){ 134b019e89cbea221598c482b05ab68b7660b41aa23saberian boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier); 135b019e89cbea221598c482b05ab68b7660b41aa23saberian return res; 136b019e89cbea221598c482b05ab68b7660b41aa23saberian } 137b019e89cbea221598c482b05ab68b7660b41aa23saberian 138b019e89cbea221598c482b05ab68b7660b41aa23saberian /** 139b019e89cbea221598c482b05ab68b7660b41aa23saberian * Print a model for debugging 140b019e89cbea221598c482b05ab68b7660b41aa23saberian */ 141b019e89cbea221598c482b05ab68b7660b41aa23saberian public void print(Model model){ 142b019e89cbea221598c482b05ab68b7660b41aa23saberian String Sw = ""; 143b019e89cbea221598c482b05ab68b7660b41aa23saberian String Sp = ""; 144b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, Float> e : model.weights.entrySet()) 145b019e89cbea221598c482b05ab68b7660b41aa23saberian Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> "; 146b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, String> e : model.parameters.entrySet()) 147b019e89cbea221598c482b05ab68b7660b41aa23saberian Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> "; 148b019e89cbea221598c482b05ab68b7660b41aa23saberian Log.i(TAG, "Weights are " + Sw); 149b019e89cbea221598c482b05ab68b7660b41aa23saberian Log.i(TAG, "Normalizer is " + model.weightNormalizer); 150b019e89cbea221598c482b05ab68b7660b41aa23saberian Log.i(TAG, "Parameters are " + Sp); 1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua @Override 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua protected void finalize() throws Throwable { 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua deleteNativeClassifier(mNativeClassifier); 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua static { 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua System.loadLibrary("bordeaux"); 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private int mNativeClassifier; 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua /* 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * The following methods are the java stubs for the jni implementations. 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private native int initNativeClassifier(); 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private native void deleteNativeClassifier(int classifierPtr); 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private native boolean nativeUpdateClassifier( 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua String[] keys_positive, 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float[] values_positive, 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua String[] keys_negative, 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua float[] values_negative, 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int classifierPtr); 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 178b019e89cbea221598c482b05ab68b7660b41aa23saberian private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr); 179b019e89cbea221598c482b05ab68b7660b41aa23saberian 180b019e89cbea221598c482b05ab68b7660b41aa23saberian private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer, 181b019e89cbea221598c482b05ab68b7660b41aa23saberian int classifierPtr); 182b019e89cbea221598c482b05ab68b7660b41aa23saberian 183b019e89cbea221598c482b05ab68b7660b41aa23saberian private native void nativeGetParameterClassifier(String [] keys, String[] values, 184b019e89cbea221598c482b05ab68b7660b41aa23saberian int classifierPtr); 185b019e89cbea221598c482b05ab68b7660b41aa23saberian 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private native int nativeGetLengthClassifier(int classifierPtr); 187b019e89cbea221598c482b05ab68b7660b41aa23saberian 188b019e89cbea221598c482b05ab68b7660b41aa23saberian private native boolean nativeSetWeightClassifier(String [] keys, float[] values, 189b019e89cbea221598c482b05ab68b7660b41aa23saberian float normalizer, int classifierPtr); 190b019e89cbea221598c482b05ab68b7660b41aa23saberian 191b019e89cbea221598c482b05ab68b7660b41aa23saberian private native boolean nativeSetParameterClassifier(String key, String value, 192b019e89cbea221598c482b05ab68b7660b41aa23saberian int classifierPtr); 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 194