1package com.android.hotspot2.est;
2
3import android.net.Network;
4import android.util.Base64;
5import android.util.Log;
6
7import com.android.hotspot2.OMADMAdapter;
8import com.android.hotspot2.asn1.Asn1Class;
9import com.android.hotspot2.asn1.Asn1Constructed;
10import com.android.hotspot2.asn1.Asn1Decoder;
11import com.android.hotspot2.asn1.Asn1ID;
12import com.android.hotspot2.asn1.Asn1Integer;
13import com.android.hotspot2.asn1.Asn1Object;
14import com.android.hotspot2.asn1.Asn1Oid;
15import com.android.hotspot2.asn1.OidMappings;
16import com.android.hotspot2.osu.HTTPHandler;
17import com.android.hotspot2.osu.OSUSocketFactory;
18import com.android.hotspot2.osu.commands.GetCertData;
19import com.android.hotspot2.pps.HomeSP;
20import com.android.hotspot2.utils.HTTPMessage;
21import com.android.hotspot2.utils.HTTPResponse;
22import com.android.org.bouncycastle.asn1.ASN1Encodable;
23import com.android.org.bouncycastle.asn1.ASN1EncodableVector;
24import com.android.org.bouncycastle.asn1.ASN1Set;
25import com.android.org.bouncycastle.asn1.DERBitString;
26import com.android.org.bouncycastle.asn1.DEREncodableVector;
27import com.android.org.bouncycastle.asn1.DERIA5String;
28import com.android.org.bouncycastle.asn1.DERObjectIdentifier;
29import com.android.org.bouncycastle.asn1.DERPrintableString;
30import com.android.org.bouncycastle.asn1.DERSet;
31import com.android.org.bouncycastle.asn1.x509.Attribute;
32import com.android.org.bouncycastle.jce.PKCS10CertificationRequest;
33import com.android.org.bouncycastle.jce.spec.ECNamedCurveGenParameterSpec;
34
35import java.io.ByteArrayInputStream;
36import java.io.IOException;
37import java.net.URL;
38import java.nio.ByteBuffer;
39import java.nio.charset.StandardCharsets;
40import java.security.AlgorithmParameters;
41import java.security.GeneralSecurityException;
42import java.security.KeyPair;
43import java.security.KeyPairGenerator;
44import java.security.KeyStore;
45import java.security.PrivateKey;
46import java.security.cert.CertificateFactory;
47import java.security.cert.X509Certificate;
48import java.util.ArrayList;
49import java.util.Arrays;
50import java.util.Collection;
51import java.util.HashMap;
52import java.util.HashSet;
53import java.util.Iterator;
54import java.util.List;
55import java.util.Map;
56import java.util.Set;
57
58import javax.net.ssl.KeyManager;
59import javax.security.auth.x500.X500Principal;
60
61//import com.android.org.bouncycastle.jce.provider.BouncyCastleProvider;
62
63public class ESTHandler implements AutoCloseable {
64    private static final String TAG = "HS2EST";
65    private static final int MinRSAKeySize = 2048;
66
67    private static final String CACERT_PATH = "/cacerts";
68    private static final String CSR_PATH = "/csrattrs";
69    private static final String SIMPLE_ENROLL_PATH = "/simpleenroll";
70    private static final String SIMPLE_REENROLL_PATH = "/simplereenroll";
71
72    private final URL mURL;
73    private final String mUser;
74    private final byte[] mPassword;
75    private final OSUSocketFactory mSocketFactory;
76    private final OMADMAdapter mOMADMAdapter;
77
78    private final List<X509Certificate> mCACerts = new ArrayList<>();
79    private final List<X509Certificate> mClientCerts = new ArrayList<>();
80    private PrivateKey mClientKey;
81
82    public ESTHandler(GetCertData certData, Network network, OMADMAdapter omadmAdapter,
83                      KeyManager km, KeyStore ks, HomeSP homeSP, int flowType)
84            throws IOException, GeneralSecurityException {
85        mURL = new URL(certData.getServer());
86        mUser = certData.getUserName();
87        mPassword = certData.getPassword();
88        mSocketFactory = OSUSocketFactory.getSocketFactory(ks, homeSP, flowType,
89                network, mURL, km, true);
90        mOMADMAdapter = omadmAdapter;
91    }
92
93    @Override
94    public void close() throws IOException {
95    }
96
97    public List<X509Certificate> getCACerts() {
98        return mCACerts;
99    }
100
101    public List<X509Certificate> getClientCerts() {
102        return mClientCerts;
103    }
104
105    public PrivateKey getClientKey() {
106        return mClientKey;
107    }
108
109    private static String indent(int amount) {
110        char[] indent = new char[amount * 2];
111        Arrays.fill(indent, ' ');
112        return new String(indent);
113    }
114
115    public void execute(boolean reenroll) throws IOException, GeneralSecurityException {
116        URL caURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
117                mURL.getFile() + CACERT_PATH);
118
119        HTTPResponse response;
120        try (HTTPHandler httpHandler = new HTTPHandler(StandardCharsets.ISO_8859_1, mSocketFactory,
121                mUser, mPassword)) {
122            response = httpHandler.doGetHTTP(caURL);
123
124            if (!"application/pkcs7-mime".equals(response.getHeaders().
125                    get(HTTPMessage.ContentTypeHeader))) {
126                throw new IOException("Unexpected Content-Type: " +
127                        response.getHeaders().get(HTTPMessage.ContentTypeHeader));
128            }
129            ByteBuffer octetBuffer = response.getBinaryPayload();
130            Collection<Asn1Object> pkcs7Content1 = Asn1Decoder.decode(octetBuffer);
131            for (Asn1Object asn1Object : pkcs7Content1) {
132                Log.d(TAG, "---");
133                Log.d(TAG, asn1Object.toString());
134            }
135            Log.d(TAG, CACERT_PATH);
136
137            mCACerts.addAll(unpackPkcs7(octetBuffer));
138            for (X509Certificate certificate : mCACerts) {
139                Log.d(TAG, "CA-Cert: " + certificate.getSubjectX500Principal());
140            }
141
142            /*
143            byte[] octets = new byte[octetBuffer.remaining()];
144            octetBuffer.duplicate().get(octets);
145            for (byte b : octets) {
146                System.out.printf("%02x ", b & 0xff);
147            }
148            Log.d(TAG, );
149            */
150
151            /* + BC
152            try {
153                byte[] octets = new byte[octetBuffer.remaining()];
154                octetBuffer.duplicate().get(octets);
155                ASN1InputStream asnin = new ASN1InputStream(octets);
156                for (int n = 0; n < 100; n++) {
157                    ASN1Primitive object = asnin.readObject();
158                    if (object == null) {
159                        break;
160                    }
161                    parseObject(object, 0);
162                }
163            }
164            catch (Throwable t) {
165                t.printStackTrace();
166            }
167
168            Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(octetBuffer);
169            for (Asn1Object asn1Object : pkcs7Content) {
170                Log.d(TAG, asn1Object);
171            }
172
173            if (pkcs7Content.size() != 1) {
174                throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
175            }
176
177            Asn1Constructed pkcs7Root = (Asn1Constructed) pkcs7Content.iterator().next();
178            Iterator<Asn1ID> certPath = Arrays.asList(Pkcs7CertPath).iterator();
179            Asn1Object certObject = pkcs7Root.findObject(certPath);
180            if (certObject == null || certPath.hasNext()) {
181                throw new IOException("Failed to find cert; returned object " + certObject +
182                        ", path " + (certPath.hasNext() ? "short" : "exhausted"));
183            }
184
185            ByteBuffer certOctets = certObject.getPayload();
186            if (certOctets == null) {
187                throw new IOException("No cert payload in: " + certObject);
188            }
189
190            byte[] certBytes = new byte[certOctets.remaining()];
191            certOctets.get(certBytes);
192
193            CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
194            Certificate cert = certFactory.generateCertificate(new ByteArrayInputStream(certBytes));
195            Log.d(TAG, "EST Cert: " + cert);
196            */
197
198            URL csrURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
199                    mURL.getFile() + CSR_PATH);
200            response = httpHandler.doGetHTTP(csrURL);
201
202            octetBuffer = response.getBinaryPayload();
203            byte[] csrData = buildCSR(octetBuffer, mOMADMAdapter, httpHandler);
204
205        /**/
206            Collection<Asn1Object> o = Asn1Decoder.decode(ByteBuffer.wrap(csrData));
207            Log.d(TAG, "CSR:");
208            Log.d(TAG, o.iterator().next().toString());
209            Log.d(TAG, "End CSR.");
210        /**/
211
212            URL enrollURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
213                    mURL.getFile() + (reenroll ? SIMPLE_REENROLL_PATH : SIMPLE_ENROLL_PATH));
214            String data = Base64.encodeToString(csrData, Base64.DEFAULT);
215            octetBuffer = httpHandler.exchangeBinary(enrollURL, data, "application/pkcs10");
216
217            Collection<Asn1Object> pkcs7Content2 = Asn1Decoder.decode(octetBuffer);
218            for (Asn1Object asn1Object : pkcs7Content2) {
219                Log.d(TAG, "---");
220                Log.d(TAG, asn1Object.toString());
221            }
222            mClientCerts.addAll(unpackPkcs7(octetBuffer));
223            for (X509Certificate cert : mClientCerts) {
224                Log.d(TAG, cert.toString());
225            }
226        }
227    }
228
229    private static final Asn1ID sSEQUENCE = new Asn1ID(Asn1Decoder.TAG_SEQ, Asn1Class.Universal);
230    private static final Asn1ID sCTXT0 = new Asn1ID(0, Asn1Class.Context);
231    private static final int PKCS7DataVersion = 1;
232    private static final int PKCS7SignedDataVersion = 3;
233
234    private static List<X509Certificate> unpackPkcs7(ByteBuffer pkcs7)
235            throws IOException, GeneralSecurityException {
236        Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(pkcs7);
237
238        if (pkcs7Content.size() != 1) {
239            throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
240        }
241
242        Asn1Object data = pkcs7Content.iterator().next();
243        if (!data.isConstructed() || !data.matches(sSEQUENCE)) {
244            throw new IOException("Expected SEQ OF, got " + data.toSimpleString());
245        } else if (data.getChildren().size() != 2) {
246            throw new IOException("Expected content info to have two children, got " +
247                    data.getChildren().size());
248        }
249
250        Iterator<Asn1Object> children = data.getChildren().iterator();
251        Asn1Object contentType = children.next();
252        if (!contentType.equals(Asn1Oid.PKCS7SignedData)) {
253            throw new IOException("Content not PKCS7 signed data");
254        }
255        Asn1Object content = children.next();
256        if (!content.isConstructed() || !content.matches(sCTXT0)) {
257            throw new IOException("Expected [CONTEXT 0] with one child, got " +
258                    content.toSimpleString() + ", " + content.getChildren().size());
259        }
260
261        Asn1Object signedData = content.getChildren().iterator().next();
262        Map<Integer, Asn1Object> itemMap = new HashMap<>();
263        for (Asn1Object item : signedData.getChildren()) {
264            if (itemMap.put(item.getTag(), item) != null && item.getTag() != Asn1Decoder.TAG_SET) {
265                throw new IOException("Duplicate item in SignedData: " + item.toSimpleString());
266            }
267        }
268
269        Asn1Object versionObject = itemMap.get(Asn1Decoder.TAG_INTEGER);
270        if (versionObject == null || !(versionObject instanceof Asn1Integer)) {
271            throw new IOException("Bad or missing PKCS7 version: " + versionObject);
272        }
273        int pkcs7version = (int) ((Asn1Integer) versionObject).getValue();
274        Asn1Object innerContentInfo = itemMap.get(Asn1Decoder.TAG_SEQ);
275        if (innerContentInfo == null ||
276                !innerContentInfo.isConstructed() ||
277                !innerContentInfo.matches(sSEQUENCE) ||
278                innerContentInfo.getChildren().size() != 1) {
279            throw new IOException("Bad or missing PKCS7 contentInfo");
280        }
281        Asn1Object contentID = innerContentInfo.getChildren().iterator().next();
282        if (pkcs7version == PKCS7DataVersion && !contentID.equals(Asn1Oid.PKCS7Data) ||
283                pkcs7version == PKCS7SignedDataVersion && !contentID.equals(Asn1Oid.PKCS7SignedData)) {
284            throw new IOException("Inner PKCS7 content (" + contentID +
285                    ") not expected for version " + pkcs7version);
286        }
287        Asn1Object certWrapper = itemMap.get(0);
288        if (certWrapper == null || !certWrapper.isConstructed() || !certWrapper.matches(sCTXT0)) {
289            throw new IOException("Expected [CONTEXT 0], got: " + certWrapper);
290        }
291
292        List<X509Certificate> certList = new ArrayList<>(certWrapper.getChildren().size());
293        CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
294        for (Asn1Object certObject : certWrapper.getChildren()) {
295            ByteBuffer certOctets = ((Asn1Constructed) certObject).getEncoding();
296            if (certOctets == null) {
297                throw new IOException("No cert payload in: " + certObject);
298            }
299            byte[] certBytes = new byte[certOctets.remaining()];
300            certOctets.get(certBytes);
301
302            certList.add((X509Certificate) certFactory.
303                    generateCertificate(new ByteArrayInputStream(certBytes)));
304        }
305        return certList;
306    }
307
308    private byte[] buildCSR(ByteBuffer octetBuffer, OMADMAdapter omadmAdapter,
309                            HTTPHandler httpHandler) throws IOException, GeneralSecurityException {
310
311        //Security.addProvider(new BouncyCastleProvider());
312
313        Log.d(TAG, "/csrattrs:");
314        /*
315        byte[] octets = new byte[octetBuffer.remaining()];
316        octetBuffer.duplicate().get(octets);
317        for (byte b : octets) {
318            System.out.printf("%02x ", b & 0xff);
319        }
320        */
321        Collection<Asn1Object> csrs = Asn1Decoder.decode(octetBuffer);
322        for (Asn1Object asn1Object : csrs) {
323            Log.d(TAG, asn1Object.toString());
324        }
325
326        if (csrs.size() != 1) {
327            throw new IOException("Unexpected object count in CSR attributes response: " +
328                    csrs.size());
329        }
330        Asn1Object sequence = csrs.iterator().next();
331        if (sequence.getClass() != Asn1Constructed.class) {
332            throw new IOException("Unexpected CSR attribute container: " + sequence);
333        }
334
335        String keyAlgo = null;
336        Asn1Oid keyAlgoOID = null;
337        String sigAlgo = null;
338        String curveName = null;
339        Asn1Oid pubCrypto = null;
340        int keySize = -1;
341        Map<Asn1Oid, ASN1Encodable> idAttributes = new HashMap<>();
342
343        for (Asn1Object child : sequence.getChildren()) {
344            if (child.getTag() == Asn1Decoder.TAG_OID) {
345                Asn1Oid oid = (Asn1Oid) child;
346                OidMappings.SigEntry sigEntry = OidMappings.getSigEntry(oid);
347                if (sigEntry != null) {
348                    sigAlgo = sigEntry.getSigAlgo();
349                    keyAlgoOID = sigEntry.getKeyAlgo();
350                    keyAlgo = OidMappings.getJCEName(keyAlgoOID);
351                } else if (oid.equals(OidMappings.sPkcs9AtChallengePassword)) {
352                    byte[] tlsUnique = httpHandler.getTLSUnique();
353                    if (tlsUnique != null) {
354                        idAttributes.put(oid, new DERPrintableString(
355                                Base64.encodeToString(tlsUnique, Base64.DEFAULT)));
356                    } else {
357                        Log.w(TAG, "Cannot retrieve TLS unique channel binding");
358                    }
359                }
360            } else if (child.getTag() == Asn1Decoder.TAG_SEQ) {
361                Asn1Oid oid = null;
362                Set<Asn1Oid> oidValues = new HashSet<>();
363                List<Asn1Object> values = new ArrayList<>();
364
365                for (Asn1Object attributeSeq : child.getChildren()) {
366                    if (attributeSeq.getTag() == Asn1Decoder.TAG_OID) {
367                        oid = (Asn1Oid) attributeSeq;
368                    } else if (attributeSeq.getTag() == Asn1Decoder.TAG_SET) {
369                        for (Asn1Object value : attributeSeq.getChildren()) {
370                            if (value.getTag() == Asn1Decoder.TAG_OID) {
371                                oidValues.add((Asn1Oid) value);
372                            } else {
373                                values.add(value);
374                            }
375                        }
376                    }
377                }
378                if (oid == null) {
379                    throw new IOException("Invalid attribute, no OID");
380                }
381                if (oid.equals(OidMappings.sExtensionRequest)) {
382                    for (Asn1Oid subOid : oidValues) {
383                        if (OidMappings.isIDAttribute(subOid)) {
384                            if (subOid.equals(OidMappings.sMAC)) {
385                                idAttributes.put(subOid, new DERIA5String(omadmAdapter.getMAC()));
386                            } else if (subOid.equals(OidMappings.sIMEI)) {
387                                idAttributes.put(subOid, new DERIA5String(omadmAdapter.getImei()));
388                            } else if (subOid.equals(OidMappings.sMEID)) {
389                                idAttributes.put(subOid, new DERBitString(omadmAdapter.getMeid()));
390                            } else if (subOid.equals(OidMappings.sDevID)) {
391                                idAttributes.put(subOid,
392                                        new DERPrintableString(omadmAdapter.getDevID()));
393                            }
394                        }
395                    }
396                } else if (OidMappings.getCryptoID(oid) != null) {
397                    pubCrypto = oid;
398                    if (!values.isEmpty()) {
399                        for (Asn1Object value : values) {
400                            if (value.getTag() == Asn1Decoder.TAG_INTEGER) {
401                                keySize = (int) ((Asn1Integer) value).getValue();
402                            }
403                        }
404                    }
405                    if (oid.equals(OidMappings.sAlgo_EC)) {
406                        if (oidValues.isEmpty()) {
407                            throw new IOException("No ECC curve name provided");
408                        }
409                        for (Asn1Oid value : oidValues) {
410                            curveName = OidMappings.getJCEName(value);
411                            if (curveName != null) {
412                                break;
413                            }
414                        }
415                        if (curveName == null) {
416                            throw new IOException("Found no ECC curve for " + oidValues);
417                        }
418                    }
419                }
420            }
421        }
422
423        if (keyAlgoOID == null) {
424            throw new IOException("No public key algorithm specified");
425        }
426        if (pubCrypto != null && !pubCrypto.equals(keyAlgoOID)) {
427            throw new IOException("Mismatching key algorithms");
428        }
429
430        if (keyAlgoOID.equals(OidMappings.sAlgo_RSA)) {
431            if (keySize < MinRSAKeySize) {
432                if (keySize >= 0) {
433                    Log.i(TAG, "Upgrading suggested RSA key size from " +
434                            keySize + " to " + MinRSAKeySize);
435                }
436                keySize = MinRSAKeySize;
437            }
438        }
439
440        Log.d(TAG, String.format("pub key '%s', signature '%s', ECC curve '%s', id-atts %s",
441                keyAlgo, sigAlgo, curveName, idAttributes));
442
443        /*
444          Ruckus:
445            SEQUENCE:
446              OID=1.2.840.113549.1.1.11 (algo_id_sha256WithRSAEncryption)
447
448          RFC-7030:
449            SEQUENCE:
450              OID=1.2.840.113549.1.9.7 (challengePassword)
451              SEQUENCE:
452                OID=1.2.840.10045.2.1 (algo_id_ecPublicKey)
453                SET:
454                  OID=1.3.132.0.34 (secp384r1)
455              SEQUENCE:
456                OID=1.2.840.113549.1.9.14 (extensionRequest)
457                SET:
458                  OID=1.3.6.1.1.1.1.22 (mac-address)
459              OID=1.2.840.10045.4.3.3 (eccdaWithSHA384)
460
461              1L, 3L, 6L, 1L, 1L, 1L, 1L, 22
462         */
463
464        // ECC Does not appear to be supported currently
465        KeyPairGenerator kpg = KeyPairGenerator.getInstance(keyAlgo);
466        if (curveName != null) {
467            AlgorithmParameters algorithmParameters = AlgorithmParameters.getInstance(keyAlgo);
468            algorithmParameters.init(new ECNamedCurveGenParameterSpec(curveName));
469            kpg.initialize(algorithmParameters
470                    .getParameterSpec(ECNamedCurveGenParameterSpec.class));
471        } else {
472            kpg.initialize(keySize);
473        }
474        KeyPair kp = kpg.generateKeyPair();
475
476        X500Principal subject = new X500Principal("CN=Android, O=Google, C=US");
477
478        mClientKey = kp.getPrivate();
479
480        // !!! Map the idAttributes into an ASN1Set of values to pass to
481        // the PKCS10CertificationRequest - this code is using outdated BC classes and
482        // has *not* been tested.
483        ASN1Set attributes;
484        if (!idAttributes.isEmpty()) {
485            ASN1EncodableVector payload = new DEREncodableVector();
486            for (Map.Entry<Asn1Oid, ASN1Encodable> entry : idAttributes.entrySet()) {
487                DERObjectIdentifier type = new DERObjectIdentifier(entry.getKey().toOIDString());
488                ASN1Set values = new DERSet(entry.getValue());
489                Attribute attribute = new Attribute(type, values);
490                payload.add(attribute);
491            }
492            attributes = new DERSet(payload);
493        } else {
494            attributes = null;
495        }
496
497        return new PKCS10CertificationRequest(sigAlgo, subject, kp.getPublic(),
498                attributes, mClientKey).getEncoded();
499    }
500}
501