1package com.android.hotspot2.osu;
2
import android.util.Log;

import java.io.IOException;
import java.net.Socket;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.net.ssl.X509KeyManager;
import javax.security.auth.x500.X500Principal;
25
26public class WiFiKeyManager implements X509KeyManager {
27    private final KeyStore mKeyStore;
28    private final Map<X500Principal, String[]> mAliases = new HashMap<>();
29
30    public WiFiKeyManager(KeyStore keyStore) throws IOException {
31        mKeyStore = keyStore;
32    }
33
34    public void enableClientAuth(List<String> issuerNames) throws GeneralSecurityException,
35            IOException {
36
37        Set<X500Principal> acceptedIssuers = new HashSet<>();
38        for (String issuerName : issuerNames) {
39            acceptedIssuers.add(new X500Principal(issuerName));
40        }
41
42        Enumeration<String> aliases = mKeyStore.aliases();
43        while (aliases.hasMoreElements()) {
44            String alias = aliases.nextElement();
45            Certificate cert = mKeyStore.getCertificate(alias);
46            if ((cert instanceof X509Certificate) && mKeyStore.getKey(alias, null) != null) {
47                X509Certificate x509Certificate = (X509Certificate) cert;
48                X500Principal issuer = x509Certificate.getIssuerX500Principal();
49                if (acceptedIssuers.contains(issuer)) {
50                    mAliases.put(issuer, new String[]{alias, cert.getPublicKey().getAlgorithm()});
51                }
52            }
53        }
54
55        if (mAliases.isEmpty()) {
56            throw new IOException("No aliases match requested issuers: " + issuerNames);
57        }
58    }
59
60    private static class AliasEntry implements Comparable<AliasEntry> {
61        private final int mPreference;
62        private final String mAlias;
63
64        private AliasEntry(int preference, String alias) {
65            mPreference = preference;
66            mAlias = alias;
67        }
68
69        public int getPreference() {
70            return mPreference;
71        }
72
73        public String getAlias() {
74            return mAlias;
75        }
76
77        @Override
78        public int compareTo(AliasEntry other) {
79            return Integer.compare(getPreference(), other.getPreference());
80        }
81    }
82
83    @Override
84    public String chooseClientAlias(String[] keyTypes, Principal[] issuers, Socket socket) {
85
86        Map<String, Integer> keyPrefs = new HashMap<>(keyTypes.length);
87        int pref = 0;
88        for (String keyType : keyTypes) {
89            keyPrefs.put(keyType, pref++);
90        }
91
92        List<AliasEntry> aliases = new ArrayList<>();
93        if (issuers != null) {
94            for (Principal issuer : issuers) {
95                if (issuer instanceof X500Principal) {
96                    String[] aliasAndKey = mAliases.get((X500Principal) issuer);
97                    if (aliasAndKey != null) {
98                        Integer preference = keyPrefs.get(aliasAndKey[1]);
99                        if (preference != null) {
100                            aliases.add(new AliasEntry(preference, aliasAndKey[0]));
101                        }
102                    }
103                }
104            }
105        } else {
106            for (String[] aliasAndKey : mAliases.values()) {
107                Integer preference = keyPrefs.get(aliasAndKey[1]);
108                if (preference != null) {
109                    aliases.add(new AliasEntry(preference, aliasAndKey[0]));
110                }
111            }
112        }
113        Collections.sort(aliases);
114        return aliases.isEmpty() ? null : aliases.get(0).getAlias();
115    }
116
117    @Override
118    public String[] getClientAliases(String keyType, Principal[] issuers) {
119        List<String> aliases = new ArrayList<>();
120        if (issuers != null) {
121            for (Principal issuer : issuers) {
122                if (issuer instanceof X500Principal) {
123                    String[] aliasAndKey = mAliases.get((X500Principal) issuer);
124                    if (aliasAndKey != null) {
125                        aliases.add(aliasAndKey[0]);
126                    }
127                }
128            }
129        } else {
130            for (String[] aliasAndKey : mAliases.values()) {
131                aliases.add(aliasAndKey[0]);
132            }
133        }
134        return aliases.isEmpty() ? null : aliases.toArray(new String[aliases.size()]);
135    }
136
137    @Override
138    public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
139        throw new UnsupportedOperationException();
140    }
141
142    @Override
143    public String[] getServerAliases(String keyType, Principal[] issuers) {
144        throw new UnsupportedOperationException();
145    }
146
147    @Override
148    public X509Certificate[] getCertificateChain(String alias) {
149        try {
150            List<X509Certificate> certs = new ArrayList<>();
151            for (Certificate certificate : mKeyStore.getCertificateChain(alias)) {
152                if (certificate instanceof X509Certificate) {
153                    certs.add((X509Certificate) certificate);
154                }
155            }
156            return certs.toArray(new X509Certificate[certs.size()]);
157        } catch (KeyStoreException kse) {
158            Log.w(OSUManager.TAG, "Failed to retrieve certificates: " + kse);
159            return null;
160        }
161    }
162
163    @Override
164    public PrivateKey getPrivateKey(String alias) {
165        try {
166            return (PrivateKey) mKeyStore.getKey(alias, null);
167        } catch (GeneralSecurityException gse) {
168            Log.w(OSUManager.TAG, "Failed to retrieve private key: " + gse);
169            return null;
170        }
171    }
172}
173