Learning_StochasticLinearRanker.java revision b019e89cbea221598c482b05ab68b7660b41aa23
16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huapackage android.bordeaux.services;
186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport android.bordeaux.learning.StochasticLinearRanker;
201dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport android.bordeaux.services.IBordeauxLearner.ModelChangeCallback;
211dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport android.os.IBinder;
226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport android.util.Log;
23b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.List;
24b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.ArrayList;
256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.io.*;
261dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.lang.ClassNotFoundException;
271dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.util.Arrays;
281dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.util.ArrayList;
291dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.util.List;
306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.Scanner;
31b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.io.ByteArrayOutputStream;
32b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.HashMap;
33b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.Map;
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
351dd8ef56681617db46caec7776c9bf416f01d8ddWei Huapublic class Learning_StochasticLinearRanker extends ILearning_StochasticLinearRanker.Stub
361dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        implements IBordeauxLearner {
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
38b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String TAG = "ILearning_StochasticLinearRanker";
39b019e89cbea221598c482b05ab68b7660b41aa23saberian    private StochasticLinearRankerWithPrior mLearningSlRanker = null;
401dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    private ModelChangeCallback modelChangeCallback = null;
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
421dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public Learning_StochasticLinearRanker(){
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
45b019e89cbea221598c482b05ab68b7660b41aa23saberian    public void ResetRanker(){
46b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
47b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
48b019e89cbea221598c482b05ab68b7660b41aa23saberian        mLearningSlRanker.resetRanker();
49b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
50b019e89cbea221598c482b05ab68b7660b41aa23saberian
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    public boolean UpdateClassifier(List<StringFloat> sample_1, List<StringFloat> sample_2){
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        ArrayList<StringFloat> temp_1 = (ArrayList<StringFloat>)sample_1;
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        String[] keys_1 = new String[temp_1.size()];
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        float[] values_1 = new float[temp_1.size()];
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        for (int i = 0; i < temp_1.size(); i++){
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            keys_1[i] = temp_1.get(i).key;
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            values_1[i] = temp_1.get(i).value;
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        }
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        ArrayList<StringFloat> temp_2 = (ArrayList<StringFloat>)sample_2;
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        String[] keys_2 = new String[temp_2.size()];
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        float[] values_2 = new float[temp_2.size()];
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        for (int i = 0; i < temp_2.size(); i++){
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            keys_2[i] = temp_2.get(i).key;
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            values_2[i] = temp_2.get(i).value;
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        }
66b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
67b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        boolean res = mLearningSlRanker.updateClassifier(keys_1,values_1,keys_2,values_2);
691dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        if (res && modelChangeCallback != null) {
701dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            modelChangeCallback.modelChanged(this);
711dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        }
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return res;
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    public float ScoreSample(List<StringFloat> sample) {
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        ArrayList<StringFloat> temp = (ArrayList<StringFloat>)sample;
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        String[] keys = new String[temp.size()];
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        float[] values = new float[temp.size()];
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        for (int i = 0; i < temp.size(); i++){
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            keys[i] = temp.get(i).key;
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            values[i] = temp.get(i).value;
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        }
83b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
84b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
85b019e89cbea221598c482b05ab68b7660b41aa23saberian        return mLearningSlRanker.scoreSample(keys,values);
86b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
87b019e89cbea221598c482b05ab68b7660b41aa23saberian
88b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean SetModelPriorWeight(List<StringFloat> sample) {
89b019e89cbea221598c482b05ab68b7660b41aa23saberian        ArrayList<StringFloat> temp = (ArrayList<StringFloat>)sample;
90b019e89cbea221598c482b05ab68b7660b41aa23saberian        HashMap<String, Float> weights = new HashMap<String, Float>();
91b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (int i = 0; i < temp.size(); i++)
92b019e89cbea221598c482b05ab68b7660b41aa23saberian            weights.put(temp.get(i).key, temp.get(i).value);
93b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
94b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
95b019e89cbea221598c482b05ab68b7660b41aa23saberian        return mLearningSlRanker.setModelPriorWeights(weights);
96b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
97b019e89cbea221598c482b05ab68b7660b41aa23saberian
98b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean SetModelParameter(String key, String value) {
99b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
100b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
101b019e89cbea221598c482b05ab68b7660b41aa23saberian        return mLearningSlRanker.setModelParameter(key,value);
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1041dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    // Beginning of the IBordeauxLearner Interface implementation
1051dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public byte [] getModel() {
106b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mLearningSlRanker == null)
107b019e89cbea221598c482b05ab68b7660b41aa23saberian            mLearningSlRanker = new StochasticLinearRankerWithPrior();
108b019e89cbea221598c482b05ab68b7660b41aa23saberian        StochasticLinearRankerWithPrior.Model model = mLearningSlRanker.getModel();
1091dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        try {
1101dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
1111dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
1121dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            objStream.writeObject(model);
1131dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            //return byteStream.toByteArray();
1141dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            byte[] bytes = byteStream.toByteArray();
1151dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            Log.i(TAG, "getModel: " + bytes);
1161dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            return bytes;
1171dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        } catch (IOException e) {
1181dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            throw new RuntimeException("Can't get model");
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        }
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1221dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public boolean setModel(final byte [] modelData) {
1231dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        try {
1241dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            ByteArrayInputStream input = new ByteArrayInputStream(modelData);
1251dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            ObjectInputStream objStream = new ObjectInputStream(input);
126b019e89cbea221598c482b05ab68b7660b41aa23saberian            StochasticLinearRankerWithPrior.Model model =
127b019e89cbea221598c482b05ab68b7660b41aa23saberian                    (StochasticLinearRankerWithPrior.Model) objStream.readObject();
128b019e89cbea221598c482b05ab68b7660b41aa23saberian            if (mLearningSlRanker == null)
129b019e89cbea221598c482b05ab68b7660b41aa23saberian                mLearningSlRanker = new StochasticLinearRankerWithPrior();
1301dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            boolean res = mLearningSlRanker.loadModel(model);
1311dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            Log.i(TAG, "LoadModel: " + modelData);
1321dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            return res;
1331dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        } catch (IOException e) {
1341dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            throw new RuntimeException("Can't load model");
1351dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        } catch (ClassNotFoundException e) {
1361dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua            throw new RuntimeException("Learning class not found");
1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        }
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1391dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua
1401dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public IBinder getBinder() {
1411dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        return this;
1421dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    }
1431dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua
1441dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public void setModelChangeCallback(ModelChangeCallback callback) {
1451dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        modelChangeCallback = callback;
1461dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    }
1471dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    // End of IBordeauxLearner Interface implemenation
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
149