1package com.android.server.wifi;
2
3import static org.mockito.Matchers.anyInt;
4import static org.mockito.Matchers.anyString;
5import static org.mockito.Mockito.mock;
6import static org.mockito.Mockito.when;
7
8import android.security.KeyStore;
9import android.util.SparseArray;
10
11import org.mockito.Matchers;
12import org.mockito.invocation.InvocationOnMock;
13import org.mockito.stubbing.Answer;
14
15import java.util.Arrays;
16import java.util.HashMap;
17
18class MockKeyStore {
19
20    public static class KeyBlob {
21        public byte[] blob;
22        public int flag;
23
24        public void update(byte[] blob, int flag) {
25            this.blob = Arrays.copyOf(blob, blob.length);
26            this.flag = flag;
27        }
28    }
29    private SparseArray<HashMap<String, KeyBlob>> mStore;
30
31    public MockKeyStore() {
32        mStore = new SparseArray<HashMap<String, KeyBlob>>();
33    }
34
35    public KeyStore createMock() {
36        KeyStore mock = mock(KeyStore.class);
37        when(mock.state()).thenReturn(KeyStore.State.UNLOCKED);
38
39        when(mock.put(anyString(), Matchers.any(byte[].class), anyInt(), anyInt()))
40                .thenAnswer(new Answer<Boolean>() {
41
42                    @Override
43                    public Boolean answer(InvocationOnMock invocation) throws Throwable {
44                        Object[] args = invocation.getArguments();
45                        return put((String) args[0], (byte[]) args[1], (Integer) args[2],
46                                (Integer) args[3]);
47                    }
48                });
49
50        when(mock.importKey(anyString(), Matchers.any(byte[].class), anyInt(), anyInt()))
51                .thenAnswer(new Answer<Boolean>() {
52
53                    @Override
54                    public Boolean answer(InvocationOnMock invocation) throws Throwable {
55                        Object[] args = invocation.getArguments();
56                        return importKey((String) args[0], (byte[]) args[1], (Integer) args[2],
57                                (Integer) args[3]);
58                    }
59                });
60
61        when(mock.delete(anyString(), anyInt())).thenAnswer(new Answer<Boolean>() {
62
63            @Override
64            public Boolean answer(InvocationOnMock invocation) throws Throwable {
65                Object[] args = invocation.getArguments();
66                return delete((String) args[0], (Integer) args[1]);
67            }
68        });
69
70        when(mock.contains(anyString(), anyInt())).thenAnswer(new Answer<Boolean>() {
71
72            @Override
73            public Boolean answer(InvocationOnMock invocation) throws Throwable {
74                Object[] args = invocation.getArguments();
75                return contains((String) args[0], (Integer) args[1]);
76            }
77        });
78
79        when(mock.duplicate(anyString(), anyInt(), anyString(), anyInt()))
80                .thenAnswer(new Answer<Boolean>() {
81                    @Override
82                    public Boolean answer(InvocationOnMock invocation) throws Throwable {
83                        Object[] args = invocation.getArguments();
84                        return duplicate((String) args[0], (Integer) args[1], (String) args[2],
85                                (Integer) args[3]);
86                    }
87                });
88        return mock;
89    }
90
91    private KeyBlob access(int uid, String key, boolean createIfNotExist) {
92        if (mStore.get(uid) == null) {
93            mStore.put(uid, new HashMap<String, KeyBlob>());
94        }
95        HashMap<String, KeyBlob> map = mStore.get(uid);
96        if (map.containsKey(key)) {
97            return map.get(key);
98        } else {
99            if (createIfNotExist) {
100                KeyBlob blob = new KeyBlob();
101                map.put(key, blob);
102                return blob;
103            } else {
104                return null;
105            }
106        }
107    }
108
109    public KeyBlob getKeyBlob(int uid, String key) {
110        return access(uid, key, false);
111    }
112
113    private boolean put(String key, byte[] value, int uid, int flags) {
114        access(uid, key, true).update(value,  flags);
115        return true;
116    }
117
118    private boolean importKey(String keyName, byte[] key, int uid, int flags) {
119        return put(keyName, key, uid, flags);
120    }
121
122    private boolean delete(String key, int uid) {
123        if (mStore.get(uid) != null) {
124            mStore.get(uid).remove(key);
125        }
126        return true;
127    }
128
129    private boolean contains(String key, int uid) {
130        return access(uid, key, false) != null;
131    }
132
133    private boolean duplicate(String srcKey, int srcUid, String destKey, int destUid) {
134        for (int i = 0; i < mStore.size(); i++) {
135            int key = mStore.keyAt(i);
136            // Cannot copy to itself
137            if (srcKey.equals(destKey) && key == destUid) {
138                continue;
139            }
140            if (srcUid == -1 || srcUid == key) {
141                HashMap<String, KeyBlob> map = mStore.get(key);
142                if (map.containsKey(srcKey)) {
143                    KeyBlob blob = map.get(srcKey);
144                    access(destUid, destKey, true).update(blob.blob, blob.flag);
145                    break;
146                }
147            }
148        }
149        return true;
150    }
151
152    @Override
153    public String toString() {
154        StringBuilder sb = new StringBuilder();
155        sb.append("KeyStore {");
156        for (int i = 0; i < mStore.size(); i++) {
157            int uid = mStore.keyAt(i);
158            for (String keyName : mStore.get(uid).keySet()) {
159                sb.append(String.format("%d:%s, ", uid, keyName));
160            }
161        }
162        sb.append("}");
163        return sb.toString();
164    }
165}